Add bionic support to endian.h am: 0b3ea84b83 am: c7689501a3
am: f128c25024 -s ours
Change-Id: Id8edfaa945afaa5e87489d914aa481b35f50d88f
diff --git a/Android.mk b/Android.mk
index 2d19e3d..88aad2b 100644
--- a/Android.mk
+++ b/Android.mk
@@ -36,27 +36,25 @@
MY_LIBTEXTCLASSIFIER_CFLAGS := \
$(MY_LIBTEXTCLASSIFIER_WARNING_CFLAGS) \
- -fvisibility=hidden
+ -fvisibility=hidden \
+ -DLIBTEXTCLASSIFIER_UNILIB_ICU
# Only enable debug logging in userdebug/eng builds.
ifneq (,$(filter userdebug eng, $(TARGET_BUILD_VARIANT)))
MY_LIBTEXTCLASSIFIER_CFLAGS += -DTC_DEBUG_LOGGING=1
endif
-# ------------------------
-# libtextclassifier_protos
-# ------------------------
+# -----------------
+# flatbuffers
+# -----------------
+# Empty static library so that other projects can include just the basic
+# FlatBuffers headers as a module.
include $(CLEAR_VARS)
-
-LOCAL_MODULE := libtextclassifier_protos
-
-LOCAL_STRIP_MODULE := $(LIBTEXTCLASSIFIER_STRIP_OPTS)
-
-LOCAL_SRC_FILES := $(call all-proto-files-under, .)
-LOCAL_SHARED_LIBRARIES := libprotobuf-cpp-lite
-
-LOCAL_CFLAGS := $(MY_LIBTEXTCLASSIFIER_WARNING_CFLAGS)
+LOCAL_MODULE := flatbuffers
+LOCAL_EXPORT_C_INCLUDES := $(LOCAL_PATH)/include
+LOCAL_EXPORT_CPPFLAGS := -std=c++11 -fexceptions -Wall \
+ -DFLATBUFFERS_TRACK_VERIFIER_BUFFER_SIZE
include $(BUILD_STATIC_LIBRARY)
@@ -67,20 +65,18 @@
include $(CLEAR_VARS)
LOCAL_MODULE := libtextclassifier
-proto_sources_dir := $(generated_sources_dir)
-
LOCAL_CPP_EXTENSION := .cc
LOCAL_CFLAGS += $(MY_LIBTEXTCLASSIFIER_CFLAGS)
LOCAL_STRIP_MODULE := $(LIBTEXTCLASSIFIER_STRIP_OPTS)
LOCAL_SRC_FILES := $(filter-out tests/% %_test.cc,$(call all-subdir-cpp-files))
-LOCAL_C_INCLUDES += $(proto_sources_dir)/proto/external/libtextclassifier
+LOCAL_C_INCLUDES := $(TOP)/external/tensorflow $(TOP)/external/flatbuffers/include
-LOCAL_STATIC_LIBRARIES += libtextclassifier_protos
-LOCAL_SHARED_LIBRARIES += libprotobuf-cpp-lite
LOCAL_SHARED_LIBRARIES += liblog
LOCAL_SHARED_LIBRARIES += libicuuc libicui18n
-LOCAL_REQUIRED_MODULES := textclassifier.smartselection.en.model
+LOCAL_SHARED_LIBRARIES += libtflite
+LOCAL_STATIC_LIBRARIES += flatbuffers
+LOCAL_REQUIRED_MODULES := textclassifier.en.model
LOCAL_ADDITIONAL_DEPENDENCIES += $(LOCAL_PATH)/jni.lds
LOCAL_LDFLAGS += -Wl,-version-script=$(LOCAL_PATH)/jni.lds
@@ -101,162 +97,30 @@
LOCAL_CFLAGS += $(MY_LIBTEXTCLASSIFIER_CFLAGS)
LOCAL_STRIP_MODULE := $(LIBTEXTCLASSIFIER_STRIP_OPTS)
-LOCAL_TEST_DATA := $(call find-test-data-in-subdirs, $(LOCAL_PATH), *, tests/testdata)
+LOCAL_TEST_DATA := $(call find-test-data-in-subdirs, $(LOCAL_PATH), *, test_data)
-LOCAL_CPPFLAGS_32 += -DTEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/tests/testdata/\""
-LOCAL_CPPFLAGS_64 += -DTEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/tests/testdata/\""
+LOCAL_CPPFLAGS_32 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/test_data/\""
+LOCAL_CPPFLAGS_64 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/test_data/\""
LOCAL_SRC_FILES := $(call all-subdir-cpp-files)
-LOCAL_C_INCLUDES += $(proto_sources_dir)/proto/external/libtextclassifier
+LOCAL_C_INCLUDES := $(TOP)/external/tensorflow $(TOP)/external/flatbuffers/include
-LOCAL_STATIC_LIBRARIES += libtextclassifier_protos libgmock
-LOCAL_SHARED_LIBRARIES += libprotobuf-cpp-lite
+LOCAL_STATIC_LIBRARIES += libgmock
LOCAL_SHARED_LIBRARIES += liblog
LOCAL_SHARED_LIBRARIES += libicuuc libicui18n
+LOCAL_SHARED_LIBRARIES += libtflite
include $(BUILD_NATIVE_TEST)
-# ------------
-# LangId model
-# ------------
-
-include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.langid.model
-LOCAL_MODULE_CLASS := ETC
-LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.langid.model
-LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
-include $(BUILD_PREBUILT)
-
# ----------------------
# Smart Selection models
# ----------------------
include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.ar.model
+LOCAL_MODULE := textclassifier.en.model
LOCAL_MODULE_CLASS := ETC
LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.ar.model
-LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
-include $(BUILD_PREBUILT)
-
-include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.de.model
-LOCAL_MODULE_CLASS := ETC
-LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.de.model
-LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
-include $(BUILD_PREBUILT)
-
-include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.en.model
-LOCAL_MODULE_CLASS := ETC
-LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.en.model
-LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
-include $(BUILD_PREBUILT)
-
-include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.es.model
-LOCAL_MODULE_CLASS := ETC
-LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.es.model
-LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
-include $(BUILD_PREBUILT)
-
-include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.fr.model
-LOCAL_MODULE_CLASS := ETC
-LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.fr.model
-LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
-include $(BUILD_PREBUILT)
-
-include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.it.model
-LOCAL_MODULE_CLASS := ETC
-LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.it.model
-LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
-include $(BUILD_PREBUILT)
-
-include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.ja.model
-LOCAL_MODULE_CLASS := ETC
-LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.ja.model
-LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
-include $(BUILD_PREBUILT)
-
-include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.ko.model
-LOCAL_MODULE_CLASS := ETC
-LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.ko.model
-LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
-include $(BUILD_PREBUILT)
-
-include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.nl.model
-LOCAL_MODULE_CLASS := ETC
-LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.nl.model
-LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
-include $(BUILD_PREBUILT)
-
-include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.pl.model
-LOCAL_MODULE_CLASS := ETC
-LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.pl.model
-LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
-include $(BUILD_PREBUILT)
-
-include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.pt.model
-LOCAL_MODULE_CLASS := ETC
-LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.pt.model
-LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
-include $(BUILD_PREBUILT)
-
-include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.ru.model
-LOCAL_MODULE_CLASS := ETC
-LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.ru.model
-LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
-include $(BUILD_PREBUILT)
-
-include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.th.model
-LOCAL_MODULE_CLASS := ETC
-LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.th.model
-LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
-include $(BUILD_PREBUILT)
-
-include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.tr.model
-LOCAL_MODULE_CLASS := ETC
-LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.tr.model
-LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
-include $(BUILD_PREBUILT)
-
-include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.zh-Hant.model
-LOCAL_MODULE_CLASS := ETC
-LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.zh-Hant.model
-LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
-include $(BUILD_PREBUILT)
-
-include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.zh.model
-LOCAL_MODULE_CLASS := ETC
-LOCAL_MODULE_OWNER := google
-LOCAL_SRC_FILES := ./models/textclassifier.smartselection.zh.model
+LOCAL_SRC_FILES := ./models/textclassifier.en.model
LOCAL_MODULE_PATH := $(TARGET_OUT_ETC)/textclassifier
include $(BUILD_PREBUILT)
@@ -265,10 +129,7 @@
# -----------------------
include $(CLEAR_VARS)
-LOCAL_MODULE := textclassifier.smartselection.bundle1
-LOCAL_REQUIRED_MODULES := textclassifier.smartselection.en.model
-LOCAL_REQUIRED_MODULES += textclassifier.smartselection.es.model
-LOCAL_REQUIRED_MODULES += textclassifier.smartselection.de.model
-LOCAL_REQUIRED_MODULES += textclassifier.smartselection.fr.model
+LOCAL_MODULE := textclassifier.bundle1
+LOCAL_REQUIRED_MODULES := textclassifier.en.model
LOCAL_CFLAGS := $(MY_LIBTEXTCLASSIFIER_WARNING_CFLAGS)
include $(BUILD_STATIC_LIBRARY)
diff --git a/cached-features.cc b/cached-features.cc
new file mode 100644
index 0000000..0b22d6d
--- /dev/null
+++ b/cached-features.cc
@@ -0,0 +1,177 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "cached-features.h"
+
+#include "tensor-view.h"
+#include "util/base/logging.h"
+
+namespace libtextclassifier2 {
+
+namespace {
+
+// Populates the features for one token into the target vector at an offset
+// corresponding to the given token index. It builds the features to populate by
+// embedding the sparse features and combining them with the dense features.
+// Embeds sparse features and the features of one token into the features
+// vector.
+bool PopulateTokenFeatures(int target_feature_index,
+ const std::vector<int>& sparse_features,
+ const std::vector<float>& dense_features,
+ int feature_vector_size,
+ EmbeddingExecutor* embedding_executor,
+ std::vector<float>* target_features) {
+ const int sparse_embedding_size = feature_vector_size - dense_features.size();
+ float* dest =
+ target_features->data() + target_feature_index * feature_vector_size;
+
+ // Embed sparse features.
+ if (!embedding_executor->AddEmbedding(
+ TensorView<int>(sparse_features.data(),
+ {static_cast<int>(sparse_features.size())}),
+ dest, sparse_embedding_size)) {
+ return false;
+ }
+
+ // Copy dense features.
+ for (int j = 0; j < dense_features.size(); ++j) {
+ dest[sparse_embedding_size + j] = dense_features[j];
+ }
+
+ return true;
+}
+
+} // namespace
+
+CachedFeatures::CachedFeatures(
+ const TokenSpan& extraction_span,
+ const std::vector<std::vector<int>>& sparse_features,
+ const std::vector<std::vector<float>>& dense_features,
+ const std::vector<int>& padding_sparse_features,
+ const std::vector<float>& padding_dense_features,
+ const FeatureProcessorOptions_::BoundsSensitiveFeatures* config,
+ EmbeddingExecutor* embedding_executor, int feature_vector_size)
+ : extraction_span_(extraction_span), config_(config) {
+ int num_extracted_tokens = 0;
+ num_extracted_tokens += config->num_tokens_before();
+ num_extracted_tokens += config->num_tokens_inside_left();
+ num_extracted_tokens += config->num_tokens_inside_right();
+ num_extracted_tokens += config->num_tokens_after();
+ if (config->include_inside_bag()) {
+ ++num_extracted_tokens;
+ }
+ output_features_size_ = num_extracted_tokens * feature_vector_size;
+ if (config->include_inside_length()) {
+ ++output_features_size_;
+ }
+
+ features_.resize(feature_vector_size * TokenSpanSize(extraction_span));
+ for (int i = 0; i < TokenSpanSize(extraction_span); ++i) {
+ if (!PopulateTokenFeatures(/*target_feature_index=*/i, sparse_features[i],
+ dense_features[i], feature_vector_size,
+ embedding_executor, &features_)) {
+ TC_LOG(ERROR) << "Could not embed sparse token features.";
+ return;
+ }
+ }
+
+ padding_features_.resize(feature_vector_size);
+ if (!PopulateTokenFeatures(/*target_feature_index=*/0,
+ padding_sparse_features, padding_dense_features,
+ feature_vector_size, embedding_executor,
+ &padding_features_)) {
+ TC_LOG(ERROR) << "Could not embed sparse padding token features.";
+ return;
+ }
+}
+
+std::vector<float> CachedFeatures::Get(TokenSpan selected_span) const {
+ selected_span.first -= extraction_span_.first;
+ selected_span.second -= extraction_span_.first;
+
+ std::vector<float> output_features;
+ output_features.reserve(output_features_size_);
+
+ // Append the features for tokens around the left bound. Masks out tokens
+ // after the right bound, so that if num_tokens_inside_left goes past it,
+ // padding tokens will be used.
+ AppendFeatures(
+ /*intended_span=*/{selected_span.first - config_->num_tokens_before(),
+ selected_span.first +
+ config_->num_tokens_inside_left()},
+ /*read_mask_span=*/{0, selected_span.second}, &output_features);
+
+ // Append the features for tokens around the right bound. Masks out tokens
+ // before the left bound, so that if num_tokens_inside_right goes past it,
+ // padding tokens will be used.
+ AppendFeatures(
+ /*intended_span=*/{selected_span.second -
+ config_->num_tokens_inside_right(),
+ selected_span.second + config_->num_tokens_after()},
+ /*read_mask_span=*/{selected_span.first, TokenSpanSize(extraction_span_)},
+ &output_features);
+
+ if (config_->include_inside_bag()) {
+ AppendSummedFeatures(selected_span, &output_features);
+ }
+
+ if (config_->include_inside_length()) {
+ output_features.push_back(static_cast<float>(TokenSpanSize(selected_span)));
+ }
+
+ return output_features;
+}
+
+void CachedFeatures::AppendFeatures(const TokenSpan& intended_span,
+ const TokenSpan& read_mask_span,
+ std::vector<float>* output_features) const {
+ const TokenSpan copy_span =
+ IntersectTokenSpans(intended_span, read_mask_span);
+ for (int i = intended_span.first; i < copy_span.first; ++i) {
+ AppendPaddingFeatures(output_features);
+ }
+ output_features->insert(
+ output_features->end(),
+ features_.begin() + copy_span.first * NumFeaturesPerToken(),
+ features_.begin() + copy_span.second * NumFeaturesPerToken());
+ for (int i = copy_span.second; i < intended_span.second; ++i) {
+ AppendPaddingFeatures(output_features);
+ }
+}
+
+void CachedFeatures::AppendPaddingFeatures(
+ std::vector<float>* output_features) const {
+ output_features->insert(output_features->end(), padding_features_.begin(),
+ padding_features_.end());
+}
+
+void CachedFeatures::AppendSummedFeatures(
+ const TokenSpan& summing_span, std::vector<float>* output_features) const {
+ const int offset = output_features->size();
+ output_features->resize(output_features->size() + NumFeaturesPerToken());
+ for (int i = summing_span.first; i < summing_span.second; ++i) {
+ for (int j = 0; j < NumFeaturesPerToken(); ++j) {
+ (*output_features)[offset + j] +=
+ features_[i * NumFeaturesPerToken() + j];
+ }
+ }
+}
+
+int CachedFeatures::NumFeaturesPerToken() const {
+ return padding_features_.size();
+}
+
+} // namespace libtextclassifier2
diff --git a/cached-features.h b/cached-features.h
new file mode 100644
index 0000000..5ffb9a9
--- /dev/null
+++ b/cached-features.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_CACHED_FEATURES_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_CACHED_FEATURES_H_
+
+#include <memory>
+#include <vector>
+
+#include "model-executor.h"
+#include "model_generated.h"
+#include "types.h"
+
+namespace libtextclassifier2 {
+
+// Holds state for extracting features across multiple calls and reusing them.
+// Assumes that features for each Token are independent.
+class CachedFeatures {
+ public:
+ CachedFeatures(
+ const TokenSpan& extraction_span,
+ const std::vector<std::vector<int>>& sparse_features,
+ const std::vector<std::vector<float>>& dense_features,
+ const std::vector<int>& padding_sparse_features,
+ const std::vector<float>& padding_dense_features,
+ const FeatureProcessorOptions_::BoundsSensitiveFeatures* config,
+ EmbeddingExecutor* embedding_executor, int feature_vector_size);
+
+ // Gets a vector of features for the given token span.
+ std::vector<float> Get(TokenSpan selected_span) const;
+
+ private:
+ // Appends token features to the output. The intended_span specifies which
+ // tokens' features should be used in principle. The read_mask_span restricts
+ // which tokens are actually read. For tokens outside of the read_mask_span,
+ // padding tokens are used instead.
+ void AppendFeatures(const TokenSpan& intended_span,
+ const TokenSpan& read_mask_span,
+ std::vector<float>* output_features) const;
+
+ // Appends features of one padding token to the output.
+ void AppendPaddingFeatures(std::vector<float>* output_features) const;
+
+ // Appends the features of tokens from the given span to the output. The
+ // features are summed so that the appended features have the size
+ // corresponding to one token.
+ void AppendSummedFeatures(const TokenSpan& summing_span,
+ std::vector<float>* output_features) const;
+
+ int NumFeaturesPerToken() const;
+
+ const TokenSpan extraction_span_;
+ const FeatureProcessorOptions_::BoundsSensitiveFeatures* config_;
+ int output_features_size_;
+ std::vector<float> features_;
+ std::vector<float> padding_features_;
+};
+
+} // namespace libtextclassifier2
+
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_CACHED_FEATURES_H_
diff --git a/cached-features_test.cc b/cached-features_test.cc
new file mode 100644
index 0000000..2412ff3
--- /dev/null
+++ b/cached-features_test.cc
@@ -0,0 +1,121 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "cached-features.h"
+
+#include "model-executor.h"
+#include "tensor-view.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using testing::ElementsAreArray;
+using testing::FloatEq;
+using testing::Matcher;
+
+namespace libtextclassifier2 {
+namespace {
+
+Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
+ std::vector<Matcher<float>> matchers;
+ for (const float value : values) {
+ matchers.push_back(FloatEq(value));
+ }
+ return ElementsAreArray(matchers);
+}
+
+// Fake EmbeddingExecutor that deterministically derives the embedding from
+class FakeEmbeddingExecutor : public EmbeddingExecutor {
+ public:
+ bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
+ int dest_size) override {
+ TC_CHECK_GE(dest_size, 2);
+ EXPECT_EQ(sparse_features.size(), 1);
+
+ dest[0] = sparse_features.data()[0] * 11.0f;
+ dest[1] = -sparse_features.data()[0] * 11.0f;
+ return true;
+ }
+
+ private:
+ std::vector<float> storage_;
+};
+
+TEST(CachedFeaturesTest, Simple) {
+ FeatureProcessorOptions_::BoundsSensitiveFeaturesT config;
+ config.enabled = true;
+ config.num_tokens_before = 2;
+ config.num_tokens_inside_left = 2;
+ config.num_tokens_inside_right = 2;
+ config.num_tokens_after = 2;
+ config.include_inside_bag = true;
+ config.include_inside_length = true;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(CreateBoundsSensitiveFeatures(builder, &config));
+ flatbuffers::DetachedBuffer config_fb = builder.Release();
+
+ std::vector<std::vector<int>> sparse_features(9);
+ for (int i = 0; i < sparse_features.size(); ++i) {
+ sparse_features[i].push_back(i + 1);
+ }
+ std::vector<std::vector<float>> dense_features(9);
+ for (int i = 0; i < dense_features.size(); ++i) {
+ dense_features[i].push_back((i + 1) * 0.1);
+ }
+
+ std::vector<int> padding_sparse_features = {10203};
+ std::vector<float> padding_dense_features = {321.0};
+
+ FakeEmbeddingExecutor executor;
+ const CachedFeatures cached_features(
+ {3, 9}, sparse_features, dense_features, padding_sparse_features,
+ padding_dense_features,
+ flatbuffers::GetRoot<FeatureProcessorOptions_::BoundsSensitiveFeatures>(
+ config_fb.data()),
+ &executor, /*feature_vector_size=*/3);
+
+ EXPECT_THAT(cached_features.Get({5, 8}),
+ ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2,
+ 33.0, -33.0, 0.3, 44.0, -44.0, 0.4,
+ 44.0, -44.0, 0.4, 55.0, -55.0, 0.5,
+ 66.0, -66.0, 0.6, 112233.0, -112233.0, 321.0,
+ 132.0, -132.0, 1.2, 3.0}));
+
+ EXPECT_THAT(
+ cached_features.Get({5, 7}),
+ ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0,
+ -33.0, 0.3, 44.0, -44.0, 0.4, 33.0, -33.0,
+ 0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5,
+ 66.0, -66.0, 0.6, 77.0, -77.0, 0.7, 2.0}));
+
+ EXPECT_THAT(
+ cached_features.Get({6, 8}),
+ ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0,
+ -44.0, 0.4, 55.0, -55.0, 0.5, 44.0, -44.0,
+ 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
+ 112233.0, -112233.0, 321.0, 99.0, -99.0, 0.9, 2.0}));
+
+ EXPECT_THAT(
+ cached_features.Get({6, 7}),
+ ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3,
+ 44.0, -44.0, 0.4, 112233.0, -112233.0, 321.0,
+ 112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4,
+ 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
+ 44.0, -44.0, 0.4, 1.0}));
+}
+
+} // namespace
+} // namespace libtextclassifier2
diff --git a/common/algorithm.h b/common/algorithm.h
deleted file mode 100644
index 365eec9..0000000
--- a/common/algorithm.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Generic utils similar to those from the C++ header <algorithm>.
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_ALGORITHM_H_
-#define LIBTEXTCLASSIFIER_COMMON_ALGORITHM_H_
-
-#include <algorithm>
-#include <vector>
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// Returns index of max element from the vector |elements|. Returns 0 if
-// |elements| is empty. T should be a type that can be compared by operator<.
-template<typename T>
-inline int GetArgMax(const std::vector<T> &elements) {
- return std::distance(
- elements.begin(),
- std::max_element(elements.begin(), elements.end()));
-}
-
-// Returns index of min element from the vector |elements|. Returns 0 if
-// |elements| is empty. T should be a type that can be compared by operator<.
-template<typename T>
-inline int GetArgMin(const std::vector<T> &elements) {
- return std::distance(
- elements.begin(),
- std::min_element(elements.begin(), elements.end()));
-}
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_ALGORITHM_H_
diff --git a/common/embedding-feature-extractor.cc b/common/embedding-feature-extractor.cc
deleted file mode 100644
index 254af45..0000000
--- a/common/embedding-feature-extractor.cc
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common/embedding-feature-extractor.h"
-
-#include <stddef.h>
-
-#include <vector>
-
-#include "common/feature-extractor.h"
-#include "common/feature-types.h"
-#include "common/task-context.h"
-#include "util/base/integral_types.h"
-#include "util/base/logging.h"
-#include "util/strings/numbers.h"
-#include "util/strings/split.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-bool GenericEmbeddingFeatureExtractor::Init(TaskContext *context) {
- // Don't use version to determine how to get feature FML.
- const std::string features = context->Get(GetParamName("features"), "");
- TC_LOG(INFO) << "Features: " << features;
-
- const std::string embedding_names =
- context->Get(GetParamName("embedding_names"), "");
- TC_LOG(INFO) << "Embedding names: " << embedding_names;
-
- const std::string embedding_dims =
- context->Get(GetParamName("embedding_dims"), "");
- TC_LOG(INFO) << "Embedding dims: " << embedding_dims;
-
- embedding_fml_ = strings::Split(features, ';');
- embedding_names_ = strings::Split(embedding_names, ';');
- for (const std::string &dim : strings::Split(embedding_dims, ';')) {
- int32 parsed_dim = 0;
- if (!ParseInt32(dim.c_str(), &parsed_dim)) {
- TC_LOG(ERROR) << "Unable to parse dim " << dim;
- return false;
- }
- embedding_dims_.push_back(parsed_dim);
- }
- if ((embedding_fml_.size() != embedding_names_.size()) ||
- (embedding_fml_.size() != embedding_dims_.size())) {
- TC_LOG(ERROR) << "Mismatch: #fml specs = " << embedding_fml_.size()
- << "; #names = " << embedding_names_.size()
- << "; #dims = " << embedding_dims_.size();
- return false;
- }
- return true;
-}
-
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/common/embedding-feature-extractor.h b/common/embedding-feature-extractor.h
deleted file mode 100644
index 0efd0d2..0000000
--- a/common/embedding-feature-extractor.h
+++ /dev/null
@@ -1,197 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_
-#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "common/feature-extractor.h"
-#include "common/task-context.h"
-#include "common/workspace.h"
-#include "util/base/logging.h"
-#include "util/base/macros.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// An EmbeddingFeatureExtractor manages the extraction of features for
-// embedding-based models. It wraps a sequence of underlying classes of feature
-// extractors, along with associated predicate maps. Each class of feature
-// extractors is associated with a name, e.g., "words", "labels", "tags".
-//
-// The class is split between a generic abstract version,
-// GenericEmbeddingFeatureExtractor (that can be initialized without knowing the
-// signature of the ExtractFeatures method) and a typed version.
-//
-// The predicate maps must be initialized before use: they can be loaded using
-// Read() or updated via UpdateMapsForExample.
-class GenericEmbeddingFeatureExtractor {
- public:
- GenericEmbeddingFeatureExtractor() {}
- virtual ~GenericEmbeddingFeatureExtractor() {}
-
- // Get the prefix std::string to put in front of all arguments, so they don't
- // conflict with other embedding models.
- virtual const std::string ArgPrefix() const = 0;
-
- // Initializes predicate maps and embedding space names that are common for
- // all embedding-based feature extractors.
- virtual bool Init(TaskContext *context);
-
- // Requests workspace for the underlying feature extractors. This is
- // implemented in the typed class.
- virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
-
- // Returns number of embedding spaces.
- int NumEmbeddings() const { return embedding_dims_.size(); }
-
- // Number of predicates for the embedding at a given index (vocabulary size).
- // Returns -1 if index is out of bounds.
- int EmbeddingSize(int index) const {
- const GenericFeatureExtractor *extractor = generic_feature_extractor(index);
- return (extractor == nullptr) ? -1 : extractor->GetDomainSize();
- }
-
- // Returns the dimensionality of the embedding space.
- int EmbeddingDims(int index) const { return embedding_dims_[index]; }
-
- // Accessor for embedding dims (dimensions of the embedding spaces).
- const std::vector<int> &embedding_dims() const { return embedding_dims_; }
-
- const std::vector<std::string> &embedding_fml() const {
- return embedding_fml_;
- }
-
- // Get parameter name by concatenating the prefix and the original name.
-  std::string GetParamName(const std::string &param_name) const {
- std::string full_name = ArgPrefix();
- full_name.push_back('_');
- full_name.append(param_name);
- return full_name;
- }
-
- protected:
- // Provides the generic class with access to the templated extractors. This is
- // used to get the type information out of the feature extractor without
- // knowing the specific calling arguments of the extractor itself.
- // Returns nullptr for an out-of-bounds idx.
- virtual const GenericFeatureExtractor *generic_feature_extractor(
- int idx) const = 0;
-
- private:
- // Embedding space names for parameter sharing.
- std::vector<std::string> embedding_names_;
-
- // FML strings for each feature extractor.
- std::vector<std::string> embedding_fml_;
-
- // Size of each of the embedding spaces (maximum predicate id).
- std::vector<int> embedding_sizes_;
-
- // Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
- std::vector<int> embedding_dims_;
-
- TC_DISALLOW_COPY_AND_ASSIGN(GenericEmbeddingFeatureExtractor);
-};
-
-// Templated, object-specific implementation of the
-// EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ,
-// ARGS...> class that has the appropriate FeatureTraits() to ensure that
-// locator type features work.
-//
-// Note: for backwards compatibility purposes, this always reads the FML spec
-// from "<prefix>_features".
-template <class EXTRACTOR, class OBJ, class... ARGS>
-class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
- public:
- // Initializes all predicate maps, feature extractors, etc.
- bool Init(TaskContext *context) override {
- if (!GenericEmbeddingFeatureExtractor::Init(context)) {
- return false;
- }
- feature_extractors_.resize(embedding_fml().size());
- for (int i = 0; i < embedding_fml().size(); ++i) {
- feature_extractors_[i].reset(new EXTRACTOR());
- if (!feature_extractors_[i]->Parse(embedding_fml()[i])) {
- return false;
- }
- if (!feature_extractors_[i]->Setup(context)) {
- return false;
- }
- }
- for (auto &feature_extractor : feature_extractors_) {
- if (!feature_extractor->Init(context)) {
- return false;
- }
- }
- return true;
- }
-
- // Requests workspaces from the registry. Must be called after Init(), and
- // before Preprocess().
- void RequestWorkspaces(WorkspaceRegistry *registry) override {
- for (auto &feature_extractor : feature_extractors_) {
- feature_extractor->RequestWorkspaces(registry);
- }
- }
-
- // Must be called on the object one state for each sentence, before any
- // feature extraction (e.g., UpdateMapsForExample, ExtractFeatures).
- void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const {
- for (auto &feature_extractor : feature_extractors_) {
- feature_extractor->Preprocess(workspaces, obj);
- }
- }
-
- // Extracts features using the extractors. Note that features must already
- // be initialized to the correct number of feature extractors. No predicate
- // mapping is applied.
- void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
- ARGS... args,
- std::vector<FeatureVector> *features) const {
- TC_DCHECK(features != nullptr);
- TC_DCHECK_EQ(features->size(), feature_extractors_.size());
- for (int i = 0; i < feature_extractors_.size(); ++i) {
- (*features)[i].clear();
- feature_extractors_[i]->ExtractFeatures(workspaces, obj, args...,
- &(*features)[i]);
- }
- }
-
- protected:
- // Provides generic access to the feature extractors.
- const GenericFeatureExtractor *generic_feature_extractor(
- int idx) const override {
- if ((idx < 0) || (idx >= feature_extractors_.size())) {
- TC_LOG(ERROR) << "Out of bounds index " << idx;
- TC_DCHECK(false); // Crash in debug mode.
- return nullptr;
- }
- return feature_extractors_[idx].get();
- }
-
- private:
- // Templated feature extractor class.
- std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_;
-};
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_
diff --git a/common/embedding-feature-extractor_test.cc b/common/embedding-feature-extractor_test.cc
deleted file mode 100644
index c5ed627..0000000
--- a/common/embedding-feature-extractor_test.cc
+++ /dev/null
@@ -1,142 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common/embedding-feature-extractor.h"
-
-#include "lang_id/language-identifier-features.h"
-#include "lang_id/light-sentence-features.h"
-#include "lang_id/light-sentence.h"
-#include "lang_id/relevant-script-feature.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-class EmbeddingFeatureExtractorTest : public ::testing::Test {
- public:
- void SetUp() override {
- // Make sure all relevant features are registered:
- lang_id::ContinuousBagOfNgramsFunction::RegisterClass();
- lang_id::RelevantScriptFeature::RegisterClass();
- }
-};
-
-// Specialization of EmbeddingFeatureExtractor that extracts from LightSentence.
-class TestEmbeddingFeatureExtractor
- : public EmbeddingFeatureExtractor<lang_id::LightSentenceExtractor,
- lang_id::LightSentence> {
- public:
- const std::string ArgPrefix() const override { return "test"; }
-};
-
-TEST_F(EmbeddingFeatureExtractorTest, NoEmbeddingSpaces) {
- TaskContext context;
- context.SetParameter("test_features", "");
- context.SetParameter("test_embedding_names", "");
- context.SetParameter("test_embedding_dims", "");
- TestEmbeddingFeatureExtractor tefe;
- ASSERT_TRUE(tefe.Init(&context));
- EXPECT_EQ(tefe.NumEmbeddings(), 0);
-}
-
-TEST_F(EmbeddingFeatureExtractorTest, GoodSpec) {
- TaskContext context;
- const std::string spec =
- "continuous-bag-of-ngrams(id_dim=5000,size=3);"
- "continuous-bag-of-ngrams(id_dim=7000,size=4)";
- context.SetParameter("test_features", spec);
- context.SetParameter("test_embedding_names", "trigram;quadgram");
- context.SetParameter("test_embedding_dims", "16;24");
- TestEmbeddingFeatureExtractor tefe;
- ASSERT_TRUE(tefe.Init(&context));
- EXPECT_EQ(tefe.NumEmbeddings(), 2);
- EXPECT_EQ(tefe.EmbeddingSize(0), 5000);
- EXPECT_EQ(tefe.EmbeddingDims(0), 16);
- EXPECT_EQ(tefe.EmbeddingSize(1), 7000);
- EXPECT_EQ(tefe.EmbeddingDims(1), 24);
-}
-
-TEST_F(EmbeddingFeatureExtractorTest, MissmatchFmlVsNames) {
- TaskContext context;
- const std::string spec =
- "continuous-bag-of-ngrams(id_dim=5000,size=3);"
- "continuous-bag-of-ngrams(id_dim=7000,size=4)";
- context.SetParameter("test_features", spec);
- context.SetParameter("test_embedding_names", "trigram");
- context.SetParameter("test_embedding_dims", "16;16");
- TestEmbeddingFeatureExtractor tefe;
- ASSERT_FALSE(tefe.Init(&context));
-}
-
-TEST_F(EmbeddingFeatureExtractorTest, MissmatchFmlVsDims) {
- TaskContext context;
- const std::string spec =
- "continuous-bag-of-ngrams(id_dim=5000,size=3);"
- "continuous-bag-of-ngrams(id_dim=7000,size=4)";
- context.SetParameter("test_features", spec);
- context.SetParameter("test_embedding_names", "trigram;quadgram");
- context.SetParameter("test_embedding_dims", "16;16;32");
- TestEmbeddingFeatureExtractor tefe;
- ASSERT_FALSE(tefe.Init(&context));
-}
-
-TEST_F(EmbeddingFeatureExtractorTest, BrokenSpec) {
- TaskContext context;
- const std::string spec =
- "continuous-bag-of-ngrams(id_dim=5000;"
- "continuous-bag-of-ngrams(id_dim=7000,size=4)";
- context.SetParameter("test_features", spec);
- context.SetParameter("test_embedding_names", "trigram;quadgram");
- context.SetParameter("test_embedding_dims", "16;16");
- TestEmbeddingFeatureExtractor tefe;
- ASSERT_FALSE(tefe.Init(&context));
-}
-
-TEST_F(EmbeddingFeatureExtractorTest, MissingFeature) {
- TaskContext context;
- const std::string spec =
- "continuous-bag-of-ngrams(id_dim=5000,size=3);"
- "no-such-feature";
- context.SetParameter("test_features", spec);
- context.SetParameter("test_embedding_names", "trigram;foo");
- context.SetParameter("test_embedding_dims", "16;16");
- TestEmbeddingFeatureExtractor tefe;
- ASSERT_FALSE(tefe.Init(&context));
-}
-
-TEST_F(EmbeddingFeatureExtractorTest, MultipleFeatures) {
- TaskContext context;
- const std::string spec =
- "continuous-bag-of-ngrams(id_dim=1000,size=3);"
- "continuous-bag-of-relevant-scripts";
- context.SetParameter("test_features", spec);
- context.SetParameter("test_embedding_names", "trigram;script");
- context.SetParameter("test_embedding_dims", "8;16");
- TestEmbeddingFeatureExtractor tefe;
- ASSERT_TRUE(tefe.Init(&context));
- EXPECT_EQ(tefe.NumEmbeddings(), 2);
- EXPECT_EQ(tefe.EmbeddingSize(0), 1000);
- EXPECT_EQ(tefe.EmbeddingDims(0), 8);
-
- // continuous-bag-of-relevant-scripts has its own hard-wired vocabulary size.
- // We don't want this test to depend on that value; we just check it's bigger
- // than 0.
- EXPECT_GT(tefe.EmbeddingSize(1), 0);
- EXPECT_EQ(tefe.EmbeddingDims(1), 16);
-}
-
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/common/embedding-network-package.proto b/common/embedding-network-package.proto
deleted file mode 100644
index 54d47e6..0000000
--- a/common/embedding-network-package.proto
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright (C) 2017 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// This file defines TaskSpec as an extension to EmbeddingNetworkProto. The
-// definition is done here rather than directly in the imported protos to keep
-// the different messages as independent as possible.
-
-syntax = "proto2";
-option optimize_for = LITE_RUNTIME;
-
-import "external/libtextclassifier/common/task-spec.proto";
-import "external/libtextclassifier/common/embedding-network.proto";
-
-package libtextclassifier.nlp_core;
-
-extend EmbeddingNetworkProto {
- optional TaskSpec task_spec_in_embedding_network_proto = 129692954;
-}
diff --git a/common/embedding-network-params-from-proto.h b/common/embedding-network-params-from-proto.h
deleted file mode 100644
index 2f2c429..0000000
--- a/common/embedding-network-params-from-proto.h
+++ /dev/null
@@ -1,245 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_
-#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_
-
-#include <algorithm>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "common/embedding-network-package.pb.h"
-#include "common/embedding-network-params.h"
-#include "common/embedding-network.pb.h"
-#include "common/float16.h"
-#include "common/little-endian-data.h"
-#include "common/task-context.h"
-#include "common/task-spec.pb.h"
-#include "util/base/integral_types.h"
-#include "util/base/logging.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// A wrapper class that owns and exposes an EmbeddingNetworkProto message via
-// the EmbeddingNetworkParams interface.
-//
-// The EmbeddingNetworkParams interface encapsulates the weight matrices of the
-// embeddings, hidden and softmax layers as transposed versions of their
-// counterparts in the original EmbeddingNetworkProto. The matrices in the proto
-// passed to this class' constructor must likewise already have been transposed.
-// See embedding-network-params.h for details.
-class EmbeddingNetworkParamsFromProto : public EmbeddingNetworkParams {
- public:
- // Constructor that takes ownership of the provided proto. See class-comment
- // for the requirements that certain weight matrices must satisfy.
- explicit EmbeddingNetworkParamsFromProto(
- std::unique_ptr<EmbeddingNetworkProto> proto)
- : proto_(std::move(proto)) {
- valid_ = true;
-
- // Initialize these vectors to have the required number of elements
- // regardless of quantization status. This is to support the unlikely case
- // where only some embeddings are quantized, along with the fact that
- // EmbeddingNetworkParams interface accesses them by index.
- embeddings_quant_scales_.resize(proto_->embeddings_size());
- embeddings_quant_weights_.resize(proto_->embeddings_size());
- for (int i = 0; i < proto_->embeddings_size(); ++i) {
- MatrixParams *embedding = proto_->mutable_embeddings()->Mutable(i);
- if (!embedding->is_quantized()) {
- continue;
- }
-
- bool success = FillVectorFromDataBytesInLittleEndian(
- embedding->bytes_for_quantized_values(),
- embedding->rows() * embedding->cols(),
- &(embeddings_quant_weights_[i]));
- if (!success) {
- TC_LOG(ERROR) << "Problem decoding quant_weights for embeddings #" << i;
- valid_ = false;
- }
-
- // The repeated field bytes_for_quantized_values uses a lot of memory.
- // Since it's no longer necessary (and we own the proto), we clear it.
- embedding->clear_bytes_for_quantized_values();
-
- success = FillVectorFromDataBytesInLittleEndian(
- embedding->bytes_for_col_scales(),
- embedding->rows(),
- &(embeddings_quant_scales_[i]));
- if (!success) {
- TC_LOG(ERROR) << "Problem decoding col_scales for embeddings #" << i;
- valid_ = false;
- }
-
- // See comments for clear_bytes_for_quantized_values().
- embedding->clear_bytes_for_col_scales();
- }
- }
-
- const TaskSpec *GetTaskSpec() override {
- if (!proto_) {
- return nullptr;
- }
- auto extension_id = task_spec_in_embedding_network_proto;
- if (proto_->HasExtension(extension_id)) {
- return &(proto_->GetExtension(extension_id));
- } else {
- TC_LOG(ERROR) << "Unable to get TaskSpec from EmbeddingNetworkProto";
- return nullptr;
- }
- }
-
- // Returns true if these params are valid. False otherwise (e.g., if the
- // original proto data was corrupted).
- bool is_valid() { return valid_; }
-
- protected:
- int embeddings_size() const override { return proto_->embeddings_size(); }
-
- int embeddings_num_rows(int i) const override {
- TC_DCHECK(InRange(i, embeddings_size()));
- return proto_->embeddings(i).rows();
- }
-
- int embeddings_num_cols(int i) const override {
- TC_DCHECK(InRange(i, embeddings_size()));
- return proto_->embeddings(i).cols();
- }
-
- const void *embeddings_weights(int i) const override {
- TC_DCHECK(InRange(i, embeddings_size()));
- if (proto_->embeddings(i).is_quantized()) {
- return static_cast<const void *>(embeddings_quant_weights_.at(i).data());
- } else {
- return static_cast<const void *>(proto_->embeddings(i).value().data());
- }
- }
-
- QuantizationType embeddings_quant_type(int i) const override {
- TC_DCHECK(InRange(i, embeddings_size()));
- return proto_->embeddings(i).is_quantized() ? QuantizationType::UINT8
- : QuantizationType::NONE;
- }
-
- const float16 *embeddings_quant_scales(int i) const override {
- TC_DCHECK(InRange(i, embeddings_size()));
- return proto_->embeddings(i).is_quantized()
- ? embeddings_quant_scales_.at(i).data()
- : nullptr;
- }
-
- int hidden_size() const override { return proto_->hidden_size(); }
-
- int hidden_num_rows(int i) const override {
- TC_DCHECK(InRange(i, hidden_size()));
- return proto_->hidden(i).rows();
- }
-
- int hidden_num_cols(int i) const override {
- TC_DCHECK(InRange(i, hidden_size()));
- return proto_->hidden(i).cols();
- }
-
- const void *hidden_weights(int i) const override {
- TC_DCHECK(InRange(i, hidden_size()));
- return proto_->hidden(i).value().data();
- }
-
- int hidden_bias_size() const override { return proto_->hidden_bias_size(); }
-
- int hidden_bias_num_rows(int i) const override {
- TC_DCHECK(InRange(i, hidden_bias_size()));
- return proto_->hidden_bias(i).rows();
- }
-
- int hidden_bias_num_cols(int i) const override {
- TC_DCHECK(InRange(i, hidden_bias_size()));
- return proto_->hidden_bias(i).cols();
- }
-
- const void *hidden_bias_weights(int i) const override {
- TC_DCHECK(InRange(i, hidden_bias_size()));
- return proto_->hidden_bias(i).value().data();
- }
-
- int softmax_size() const override { return proto_->has_softmax() ? 1 : 0; }
-
- int softmax_num_rows(int i) const override {
- TC_DCHECK(InRange(i, softmax_size()));
- return proto_->has_softmax() ? proto_->softmax().rows() : 0;
- }
-
- int softmax_num_cols(int i) const override {
- TC_DCHECK(InRange(i, softmax_size()));
- return proto_->has_softmax() ? proto_->softmax().cols() : 0;
- }
-
- const void *softmax_weights(int i) const override {
- TC_DCHECK(InRange(i, softmax_size()));
- return proto_->has_softmax() ? proto_->softmax().value().data() : nullptr;
- }
-
- int softmax_bias_size() const override {
- return proto_->has_softmax_bias() ? 1 : 0;
- }
-
- int softmax_bias_num_rows(int i) const override {
- TC_DCHECK(InRange(i, softmax_bias_size()));
- return proto_->has_softmax_bias() ? proto_->softmax_bias().rows() : 0;
- }
-
- int softmax_bias_num_cols(int i) const override {
- TC_DCHECK(InRange(i, softmax_bias_size()));
- return proto_->has_softmax_bias() ? proto_->softmax_bias().cols() : 0;
- }
-
- const void *softmax_bias_weights(int i) const override {
- TC_DCHECK(InRange(i, softmax_bias_size()));
- return proto_->has_softmax_bias() ? proto_->softmax_bias().value().data()
- : nullptr;
- }
-
- int embedding_num_features_size() const override {
- return proto_->embedding_num_features_size();
- }
-
- int embedding_num_features(int i) const override {
- TC_DCHECK(InRange(i, embedding_num_features_size()));
- return proto_->embedding_num_features(i);
- }
-
- private:
- std::unique_ptr<EmbeddingNetworkProto> proto_;
-
- // True if these params are valid. May be false if the original proto was
- // corrupted. We prefer to set this to false to CHECK-failing.
- bool valid_;
-
- // When the embeddings are quantized, these members are used to store their
- // numeric values using the types expected by the rest of the class. Due to
- // technical reasons, the proto stores this info using larger types (i.e.,
- // more bits).
- std::vector<std::vector<float16>> embeddings_quant_scales_;
- std::vector<std::vector<uint8>> embeddings_quant_weights_;
-};
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_
diff --git a/common/embedding-network-params.h b/common/embedding-network-params.h
deleted file mode 100755
index ee2d9dc..0000000
--- a/common/embedding-network-params.h
+++ /dev/null
@@ -1,325 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_
-#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_
-
-#include <algorithm>
-#include <string>
-
-#include "common/float16.h"
-#include "common/task-context.h"
-#include "common/task-spec.pb.h"
-#include "util/base/logging.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-enum class QuantizationType { NONE = 0, UINT8 };
-
-// API for accessing parameters for a feed-forward neural network with
-// embeddings.
-//
-// Note: this API is closely related to embedding-network.proto. The reason we
-// have a separate API is that the proto may not be the only way of packaging
-// these parameters.
-class EmbeddingNetworkParams {
- public:
- virtual ~EmbeddingNetworkParams() {}
-
- // **** High-level API.
-
- // Simple representation of a matrix. This small struct that doesn't own any
- // resource intentionally supports copy / assign, to simplify our APIs.
- struct Matrix {
- // Number of rows.
- int rows;
-
- // Number of columns.
- int cols;
-
- QuantizationType quant_type;
-
- // Pointer to matrix elements, in row-major order
- // (https://en.wikipedia.org/wiki/Row-major_order) Not owned.
- const void *elements;
-
- // Quantization scales: one scale for each row.
- const float16 *quant_scales;
- };
-
- // Returns number of embedding spaces.
- int GetNumEmbeddingSpaces() const {
- if (embeddings_size() != embedding_num_features_size()) {
- TC_LOG(ERROR) << "Embedding spaces mismatch " << embeddings_size()
- << " != " << embedding_num_features_size();
- }
- return std::max(0,
- std::min(embeddings_size(), embedding_num_features_size()));
- }
-
- // Returns embedding matrix for the i-th embedding space.
- //
- // NOTE: i must be in [0, GetNumEmbeddingSpaces()). Undefined behavior
- // otherwise.
- Matrix GetEmbeddingMatrix(int i) const {
- TC_DCHECK(InRange(i, embeddings_size()));
- Matrix matrix;
- matrix.rows = embeddings_num_rows(i);
- matrix.cols = embeddings_num_cols(i);
- matrix.elements = embeddings_weights(i);
- matrix.quant_type = embeddings_quant_type(i);
- matrix.quant_scales = embeddings_quant_scales(i);
- return matrix;
- }
-
- // Returns number of features in i-th embedding space.
- //
- // NOTE: i must be in [0, GetNumEmbeddingSpaces()). Undefined behavior
- // otherwise.
- int GetNumFeaturesInEmbeddingSpace(int i) const {
- TC_DCHECK(InRange(i, embedding_num_features_size()));
- return std::max(0, embedding_num_features(i));
- }
-
- // Returns number of hidden layers in the neural network. Each such layer has
- // weight matrix and a bias vector (a matrix with one column).
- int GetNumHiddenLayers() const {
- if (hidden_size() != hidden_bias_size()) {
- TC_LOG(ERROR) << "Hidden layer mismatch " << hidden_size()
- << " != " << hidden_bias_size();
- }
- return std::max(0, std::min(hidden_size(), hidden_bias_size()));
- }
-
- // Returns weight matrix for i-th hidden layer.
- //
- // NOTE: i must be in [0, GetNumHiddenLayers()). Undefined behavior
- // otherwise.
- Matrix GetHiddenLayerMatrix(int i) const {
- TC_DCHECK(InRange(i, hidden_size()));
- Matrix matrix;
- matrix.rows = hidden_num_rows(i);
- matrix.cols = hidden_num_cols(i);
-
- // Quantization not supported here.
- matrix.quant_type = QuantizationType::NONE;
- matrix.elements = hidden_weights(i);
- return matrix;
- }
-
- // Returns bias matrix for i-th hidden layer. Technically a Matrix, but we
- // expect it to be a vector (i.e., num cols is 1).
- //
- // NOTE: i must be in [0, GetNumHiddenLayers()). Undefined behavior
- // otherwise.
- Matrix GetHiddenLayerBias(int i) const {
- TC_DCHECK(InRange(i, hidden_bias_size()));
- Matrix matrix;
- matrix.rows = hidden_bias_num_rows(i);
- matrix.cols = hidden_bias_num_cols(i);
-
- // Quantization not supported here.
- matrix.quant_type = QuantizationType::NONE;
- matrix.elements = hidden_bias_weights(i);
- return matrix;
- }
-
- // Returns true if a softmax layer exists.
- bool HasSoftmaxLayer() const {
- if (softmax_size() != softmax_bias_size()) {
- TC_LOG(ERROR) << "Softmax layer mismatch " << softmax_size()
- << " != " << softmax_bias_size();
- }
- return (softmax_size() == 1) && (softmax_bias_size() == 1);
- }
-
- // Returns weight matrix for the softmax layer.
- //
- // NOTE: Should be called only if HasSoftmaxLayer() is true. Undefined
- // behavior otherwise.
- Matrix GetSoftmaxMatrix() const {
- TC_DCHECK(softmax_size() == 1);
- Matrix matrix;
- matrix.rows = softmax_num_rows(0);
- matrix.cols = softmax_num_cols(0);
-
- // Quantization not supported here.
- matrix.quant_type = QuantizationType::NONE;
- matrix.elements = softmax_weights(0);
- return matrix;
- }
-
- // Returns bias for the softmax layer. Technically a Matrix, but we expect it
- // to be a row/column vector (i.e., num cols is 1).
- //
- // NOTE: Should be called only if HasSoftmaxLayer() is true. Undefined
- // behavior otherwise.
- Matrix GetSoftmaxBias() const {
- TC_DCHECK(softmax_bias_size() == 1);
- Matrix matrix;
- matrix.rows = softmax_bias_num_rows(0);
- matrix.cols = softmax_bias_num_cols(0);
-
- // Quantization not supported here.
- matrix.quant_type = QuantizationType::NONE;
- matrix.elements = softmax_bias_weights(0);
- return matrix;
- }
-
- // Updates the EmbeddingNetwork-related parameters from task_context. Returns
- // true on success, false on error.
- virtual bool UpdateTaskContextParameters(TaskContext *task_context) {
- const TaskSpec *task_spec = GetTaskSpec();
- if (task_spec == nullptr) {
- TC_LOG(ERROR) << "Unable to get TaskSpec";
- return false;
- }
- for (const TaskSpec::Parameter ¶meter : task_spec->parameter()) {
- task_context->SetParameter(parameter.name(), parameter.value());
- }
- return true;
- }
-
- // Returns a pointer to a TaskSpec with the EmbeddingNetwork-related
- // parameters. Returns nullptr in case of problems. Ownership with the
- // returned pointer is *not* transfered to the caller.
- virtual const TaskSpec *GetTaskSpec() {
- TC_LOG(ERROR) << "Not implemented";
- return nullptr;
- }
-
- protected:
- // **** Low-level API.
- //
- // * Most low-level API methods are documented by giving an equivalent
- // function call on proto, the original proto (of type
- // EmbeddingNetworkProto) which was used to generate the C++ code.
- //
- // * To simplify our generation code, optional proto fields of message type
- // are treated as repeated fields with 0 or 1 instances. As such, we have
- // *_size() methods for such optional fields: they return 0 or 1.
- //
- // * "transpose(M)" denotes the transpose of a matrix M.
- //
- // * Behavior is undefined when trying to retrieve a piece of data that does
- // not exist: e.g., embeddings_num_rows(5) if embeddings_size() == 2.
-
- // ** Access methods for repeated MatrixParams embeddings.
- //
- // Returns proto.embeddings_size().
- virtual int embeddings_size() const = 0;
-
- // Returns number of rows of transpose(proto.embeddings(i)).
- virtual int embeddings_num_rows(int i) const = 0;
-
- // Returns number of columns of transpose(proto.embeddings(i)).
- virtual int embeddings_num_cols(int i) const = 0;
-
- // Returns pointer to elements of transpose(proto.embeddings(i)), in row-major
- // order. NOTE: for unquantized embeddings, this returns a pointer to float;
- // for quantized embeddings, this returns a pointer to uint8.
- virtual const void *embeddings_weights(int i) const = 0;
-
- virtual QuantizationType embeddings_quant_type(int i) const {
- return QuantizationType::NONE;
- }
-
- virtual const float16 *embeddings_quant_scales(int i) const {
- return nullptr;
- }
-
- // ** Access methods for repeated MatrixParams hidden.
- //
- // Returns embedding_network_proto.hidden_size().
- virtual int hidden_size() const = 0;
-
- // Returns embedding_network_proto.hidden(i).rows().
- virtual int hidden_num_rows(int i) const = 0;
-
- // Returns embedding_network_proto.hidden(i).rows().
- virtual int hidden_num_cols(int i) const = 0;
-
- // Returns pointer to beginning of array of floats with all values from
- // embedding_network_proto.hidden(i).
- virtual const void *hidden_weights(int i) const = 0;
-
- // ** Access methods for repeated MatrixParams hidden_bias.
- //
- // Returns proto.hidden_bias_size().
- virtual int hidden_bias_size() const = 0;
-
- // Returns number of rows of proto.hidden_bias(i).
- virtual int hidden_bias_num_rows(int i) const = 0;
-
- // Returns number of columns of proto.hidden_bias(i).
- virtual int hidden_bias_num_cols(int i) const = 0;
-
- // Returns pointer to elements of proto.hidden_bias(i), in row-major order.
- virtual const void *hidden_bias_weights(int i) const = 0;
-
- // ** Access methods for optional MatrixParams softmax.
- //
- // Returns 1 if proto has optional field softmax, 0 otherwise.
- virtual int softmax_size() const = 0;
-
- // Returns number of rows of transpose(proto.softmax()).
- virtual int softmax_num_rows(int i) const = 0;
-
- // Returns number of columns of transpose(proto.softmax()).
- virtual int softmax_num_cols(int i) const = 0;
-
- // Returns pointer to elements of transpose(proto.softmax()), in row-major
- // order.
- virtual const void *softmax_weights(int i) const = 0;
-
- // ** Access methods for optional MatrixParams softmax_bias.
- //
- // Returns 1 if proto has optional field softmax_bias, 0 otherwise.
- virtual int softmax_bias_size() const = 0;
-
- // Returns number of rows of proto.softmax_bias().
- virtual int softmax_bias_num_rows(int i) const = 0;
-
- // Returns number of columns of proto.softmax_bias().
- virtual int softmax_bias_num_cols(int i) const = 0;
-
- // Returns pointer to elements of proto.softmax_bias(), in row-major order.
- virtual const void *softmax_bias_weights(int i) const = 0;
-
- // ** Access methods for repeated int32 embedding_num_features.
- //
- // Returns proto.embedding_num_features_size().
- virtual int embedding_num_features_size() const = 0;
-
- // Returns proto.embedding_num_features(i).
- virtual int embedding_num_features(int i) const = 0;
-
- // Returns true if and only if index is in range [0, size). Log an error
- // message otherwise.
- static bool InRange(int index, int size) {
- if ((index < 0) || (index >= size)) {
- TC_LOG(ERROR) << "Index " << index << " outside [0, " << size << ")";
- return false;
- }
- return true;
- }
-}; // class EmbeddingNetworkParams
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_
diff --git a/common/embedding-network.cc b/common/embedding-network.cc
deleted file mode 100644
index b27cda3..0000000
--- a/common/embedding-network.cc
+++ /dev/null
@@ -1,380 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common/embedding-network.h"
-
-#include <math.h>
-
-#include "common/simple-adder.h"
-#include "util/base/integral_types.h"
-#include "util/base/logging.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-namespace {
-
-// Returns true if and only if matrix does not use any quantization.
-bool CheckNoQuantization(const EmbeddingNetworkParams::Matrix &matrix) {
- if (matrix.quant_type != QuantizationType::NONE) {
- TC_LOG(ERROR) << "Unsupported quantization";
- TC_DCHECK(false); // Crash in debug mode.
- return false;
- }
- return true;
-}
-
-// Initializes a Matrix object with the parameters from the MatrixParams
-// source_matrix. source_matrix should not use quantization.
-//
-// Returns true on success, false on error.
-bool InitNonQuantizedMatrix(const EmbeddingNetworkParams::Matrix &source_matrix,
- EmbeddingNetwork::Matrix *mat) {
- mat->resize(source_matrix.rows);
-
- // Before we access the weights as floats, we need to check that they are
- // really floats, i.e., no quantization is used.
- if (!CheckNoQuantization(source_matrix)) return false;
- const float *weights =
- reinterpret_cast<const float *>(source_matrix.elements);
- for (int r = 0; r < source_matrix.rows; ++r) {
- (*mat)[r] = EmbeddingNetwork::VectorWrapper(weights, source_matrix.cols);
- weights += source_matrix.cols;
- }
- return true;
-}
-
-// Initializes a VectorWrapper object with the parameters from the MatrixParams
-// source_matrix. source_matrix should have exactly one column and should not
-// use quantization.
-//
-// Returns true on success, false on error.
-bool InitNonQuantizedVector(const EmbeddingNetworkParams::Matrix &source_matrix,
- EmbeddingNetwork::VectorWrapper *vector) {
- if (source_matrix.cols != 1) {
- TC_LOG(ERROR) << "wrong #cols " << source_matrix.cols;
- return false;
- }
- if (!CheckNoQuantization(source_matrix)) {
- TC_LOG(ERROR) << "unsupported quantization";
- return false;
- }
- // Before we access the weights as floats, we need to check that they are
- // really floats, i.e., no quantization is used.
- if (!CheckNoQuantization(source_matrix)) return false;
- const float *weights =
- reinterpret_cast<const float *>(source_matrix.elements);
- *vector = EmbeddingNetwork::VectorWrapper(weights, source_matrix.rows);
- return true;
-}
-
-// Computes y = weights * Relu(x) + b where Relu is optionally applied.
-template <typename ScaleAdderClass>
-bool SparseReluProductPlusBias(bool apply_relu,
- const EmbeddingNetwork::Matrix &weights,
- const EmbeddingNetwork::VectorWrapper &b,
- const VectorSpan<float> &x,
- EmbeddingNetwork::Vector *y) {
- // Check that dimensions match.
- if ((x.size() != weights.size()) || weights.empty()) {
- TC_LOG(ERROR) << x.size() << " != " << weights.size();
- return false;
- }
- if (weights[0].size() != b.size()) {
- TC_LOG(ERROR) << weights[0].size() << " != " << b.size();
- return false;
- }
-
- y->assign(b.data(), b.data() + b.size());
- ScaleAdderClass adder(y->data(), y->size());
-
- const int x_size = x.size();
- for (int i = 0; i < x_size; ++i) {
- const float &scale = x[i];
- if (apply_relu) {
- if (scale > 0) {
- adder.LazyScaleAdd(weights[i].data(), scale);
- }
- } else {
- adder.LazyScaleAdd(weights[i].data(), scale);
- }
- }
- return true;
-}
-} // namespace
-
-bool EmbeddingNetwork::ConcatEmbeddings(
- const std::vector<FeatureVector> &feature_vectors, Vector *concat) const {
- concat->resize(concat_layer_size_);
-
- // Invariant 1: feature_vectors contains exactly one element for each
- // embedding space. That element is itself a FeatureVector, which may be
- // empty, but it should be there.
- if (feature_vectors.size() != embedding_matrices_.size()) {
- TC_LOG(ERROR) << feature_vectors.size()
- << " != " << embedding_matrices_.size();
- return false;
- }
-
- // "es_index" stands for "embedding space index".
- for (int es_index = 0; es_index < feature_vectors.size(); ++es_index) {
- // Access is safe by es_index loop bounds and Invariant 1.
- EmbeddingMatrix *const embedding_matrix =
- embedding_matrices_[es_index].get();
- if (embedding_matrix == nullptr) {
- // Should not happen, hence our terse log error message.
- TC_LOG(ERROR) << es_index;
- return false;
- }
-
- // Access is safe due to es_index loop bounds.
- const FeatureVector &feature_vector = feature_vectors[es_index];
-
- // Access is safe by es_index loop bounds, Invariant 1, and Invariant 2.
- const int concat_offset = concat_offset_[es_index];
-
- if (!GetEmbeddingInternal(feature_vector, embedding_matrix, concat_offset,
- concat->data(), concat->size())) {
- TC_LOG(ERROR) << es_index;
- return false;
- }
- }
- return true;
-}
-
-bool EmbeddingNetwork::GetEmbedding(const FeatureVector &feature_vector,
- int es_index, float *embedding) const {
- EmbeddingMatrix *const embedding_matrix = embedding_matrices_[es_index].get();
- if (embedding_matrix == nullptr) {
- // Should not happen, hence our terse log error message.
- TC_LOG(ERROR) << es_index;
- return false;
- }
- return GetEmbeddingInternal(feature_vector, embedding_matrix, 0, embedding,
- embedding_matrices_[es_index]->dim());
-}
-
-bool EmbeddingNetwork::GetEmbeddingInternal(
- const FeatureVector &feature_vector,
- EmbeddingMatrix *const embedding_matrix, const int concat_offset,
- float *concat, int concat_size) const {
- const int embedding_dim = embedding_matrix->dim();
- const bool is_quantized =
- embedding_matrix->quant_type() != QuantizationType::NONE;
- const int num_features = feature_vector.size();
- for (int fi = 0; fi < num_features; ++fi) {
- // Both accesses below are safe due to loop bounds for fi.
- const FeatureType *feature_type = feature_vector.type(fi);
- const FeatureValue feature_value = feature_vector.value(fi);
- const int feature_offset =
- concat_offset + feature_type->base() * embedding_dim;
-
- // Code below updates max(0, embedding_dim) elements from concat, starting
- // with index feature_offset. Check below ensures these updates are safe.
- if ((feature_offset < 0) ||
- (feature_offset + embedding_dim > concat_size)) {
- TC_LOG(ERROR) << fi << ": " << feature_offset << " " << embedding_dim
- << " " << concat_size;
- return false;
- }
-
- // Pointer to float / uint8 weights for relevant embedding.
- const void *embedding_data;
-
- // Multiplier for each embedding weight.
- float multiplier;
-
- if (feature_type->is_continuous()) {
- // Continuous features (encoded as FloatFeatureValue).
- FloatFeatureValue float_feature_value(feature_value);
- const int id = float_feature_value.id;
- embedding_matrix->get_embedding(id, &embedding_data, &multiplier);
- multiplier *= float_feature_value.weight;
- } else {
- // Discrete features: every present feature has implicit value 1.0.
- // Hence, after we grab the multiplier below, we don't multiply it by
- // any weight.
- embedding_matrix->get_embedding(feature_value, &embedding_data,
- &multiplier);
- }
-
- // Weighted embeddings will be added starting from this address.
- float *concat_ptr = concat + feature_offset;
-
- if (is_quantized) {
- const uint8 *quant_weights =
- reinterpret_cast<const uint8 *>(embedding_data);
- for (int i = 0; i < embedding_dim; ++i, ++quant_weights, ++concat_ptr) {
- // 128 is bias for UINT8 quantization, only one we currently support.
- *concat_ptr += (static_cast<int>(*quant_weights) - 128) * multiplier;
- }
- } else {
- const float *weights = reinterpret_cast<const float *>(embedding_data);
- for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) {
- *concat_ptr += *weights * multiplier;
- }
- }
- }
- return true;
-}
-
-bool EmbeddingNetwork::ComputeLogits(const VectorSpan<float> &input,
- Vector *scores) const {
- return EmbeddingNetwork::ComputeLogitsInternal(input, scores);
-}
-
-bool EmbeddingNetwork::ComputeLogits(const Vector &input,
- Vector *scores) const {
- return EmbeddingNetwork::ComputeLogitsInternal(input, scores);
-}
-
-bool EmbeddingNetwork::ComputeLogitsInternal(const VectorSpan<float> &input,
- Vector *scores) const {
- return FinishComputeFinalScoresInternal<SimpleAdder>(input, scores);
-}
-
-template <typename ScaleAdderClass>
-bool EmbeddingNetwork::FinishComputeFinalScoresInternal(
- const VectorSpan<float> &input, Vector *scores) const {
- // This vector serves as an alternating storage for activations of the
- // different layers. We can't use just one vector here because all of the
- // activations of the previous layer are needed for computation of
- // activations of the next one.
- std::vector<Vector> h_storage(2);
-
- // Compute pre-logits activations.
- VectorSpan<float> h_in(input);
- Vector *h_out;
- for (int i = 0; i < hidden_weights_.size(); ++i) {
- const bool apply_relu = i > 0;
- h_out = &(h_storage[i % 2]);
- h_out->resize(hidden_bias_[i].size());
- if (!SparseReluProductPlusBias<ScaleAdderClass>(
- apply_relu, hidden_weights_[i], hidden_bias_[i], h_in, h_out)) {
- return false;
- }
- h_in = VectorSpan<float>(*h_out);
- }
-
- // Compute logit scores.
- if (!SparseReluProductPlusBias<ScaleAdderClass>(
- true, softmax_weights_, softmax_bias_, h_in, scores)) {
- return false;
- }
-
- return true;
-}
-
-bool EmbeddingNetwork::ComputeFinalScores(
- const std::vector<FeatureVector> &features, Vector *scores) const {
- return ComputeFinalScores(features, {}, scores);
-}
-
-bool EmbeddingNetwork::ComputeFinalScores(
- const std::vector<FeatureVector> &features,
- const std::vector<float> extra_inputs, Vector *scores) const {
- // If we haven't successfully initialized, return without doing anything.
- if (!is_valid()) return false;
-
- Vector concat;
- if (!ConcatEmbeddings(features, &concat)) return false;
-
- if (!extra_inputs.empty()) {
- concat.reserve(concat.size() + extra_inputs.size());
- for (int i = 0; i < extra_inputs.size(); i++) {
- concat.push_back(extra_inputs[i]);
- }
- }
-
- scores->resize(softmax_bias_.size());
- return ComputeLogits(concat, scores);
-}
-
-EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model) {
- // We'll set valid_ to true only if construction is successful. If we detect
- // an error along the way, we log an informative message and return early, but
- // we do not crash.
- valid_ = false;
-
- // Fill embedding_matrices_, concat_offset_, and concat_layer_size_.
- const int num_embedding_spaces = model->GetNumEmbeddingSpaces();
- int offset_sum = 0;
- for (int i = 0; i < num_embedding_spaces; ++i) {
- concat_offset_.push_back(offset_sum);
- const EmbeddingNetworkParams::Matrix matrix = model->GetEmbeddingMatrix(i);
- if (matrix.quant_type != QuantizationType::UINT8) {
- TC_LOG(ERROR) << "Unsupported quantization for embedding #" << i << ": "
- << static_cast<int>(matrix.quant_type);
- return;
- }
-
- // There is no way to accomodate an empty embedding matrix. E.g., there is
- // no way for get_embedding to return something that can be read safely.
- // Hence, we catch that error here and return early.
- if (matrix.rows == 0) {
- TC_LOG(ERROR) << "Empty embedding matrix #" << i;
- return;
- }
- embedding_matrices_.emplace_back(new EmbeddingMatrix(matrix));
- const int embedding_dim = embedding_matrices_.back()->dim();
- offset_sum += embedding_dim * model->GetNumFeaturesInEmbeddingSpace(i);
- }
- concat_layer_size_ = offset_sum;
-
- // Invariant 2 (trivial by the code above).
- TC_DCHECK_EQ(concat_offset_.size(), embedding_matrices_.size());
-
- const int num_hidden_layers = model->GetNumHiddenLayers();
- if (num_hidden_layers < 1) {
- TC_LOG(ERROR) << "Wrong number of hidden layers: " << num_hidden_layers;
- return;
- }
- hidden_weights_.resize(num_hidden_layers);
- hidden_bias_.resize(num_hidden_layers);
-
- for (int i = 0; i < num_hidden_layers; ++i) {
- const EmbeddingNetworkParams::Matrix matrix =
- model->GetHiddenLayerMatrix(i);
- const EmbeddingNetworkParams::Matrix bias = model->GetHiddenLayerBias(i);
- if (!InitNonQuantizedMatrix(matrix, &hidden_weights_[i]) ||
- !InitNonQuantizedVector(bias, &hidden_bias_[i])) {
- TC_LOG(ERROR) << "Bad hidden layer #" << i;
- return;
- }
- }
-
- if (!model->HasSoftmaxLayer()) {
- TC_LOG(ERROR) << "Missing softmax layer";
- return;
- }
- const EmbeddingNetworkParams::Matrix softmax = model->GetSoftmaxMatrix();
- const EmbeddingNetworkParams::Matrix softmax_bias = model->GetSoftmaxBias();
- if (!InitNonQuantizedMatrix(softmax, &softmax_weights_) ||
- !InitNonQuantizedVector(softmax_bias, &softmax_bias_)) {
- TC_LOG(ERROR) << "Bad softmax layer";
- return;
- }
-
- // Everything looks good.
- valid_ = true;
-}
-
-int EmbeddingNetwork::EmbeddingSize(int es_index) const {
- return embedding_matrices_[es_index]->dim();
-}
-
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/common/embedding-network.h b/common/embedding-network.h
deleted file mode 100644
index a02c6ea..0000000
--- a/common/embedding-network.h
+++ /dev/null
@@ -1,246 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_
-#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_
-
-#include <memory>
-#include <vector>
-
-#include "common/embedding-network-params.h"
-#include "common/feature-extractor.h"
-#include "common/vector-span.h"
-#include "util/base/integral_types.h"
-#include "util/base/logging.h"
-#include "util/base/macros.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// Classifier using a hand-coded feed-forward neural network.
-//
-// No gradient computation, just inference.
-//
-// Classification works as follows:
-//
-// Discrete features -> Embeddings -> Concatenation -> Hidden+ -> Softmax
-//
-// In words: given some discrete features, this class extracts the embeddings
-// for these features, concatenates them, passes them through one or two hidden
-// layers (each layer uses Relu) and next through a softmax layer that computes
-// an unnormalized score for each possible class. Note: there is always a
-// softmax layer.
-class EmbeddingNetwork {
- public:
- // Class used to represent an embedding matrix. Each row is the embedding on
- // a vocabulary element. Number of columns = number of embedding dimensions.
- class EmbeddingMatrix {
- public:
- explicit EmbeddingMatrix(const EmbeddingNetworkParams::Matrix source_matrix)
- : rows_(source_matrix.rows),
- cols_(source_matrix.cols),
- quant_type_(source_matrix.quant_type),
- data_(source_matrix.elements),
- row_size_in_bytes_(GetRowSizeInBytes(cols_, quant_type_)),
- quant_scales_(source_matrix.quant_scales) {}
-
- // Returns vocabulary size; one embedding for each vocabulary element.
- int size() const { return rows_; }
-
- // Returns number of weights in embedding of each vocabulary element.
- int dim() const { return cols_; }
-
- // Returns quantization type for this embedding matrix.
- QuantizationType quant_type() const { return quant_type_; }
-
- // Gets embedding for k-th vocabulary element: on return, sets *data to
- // point to the embedding weights and *scale to the quantization scale (1.0
- // if no quantization).
- void get_embedding(int k, const void **data, float *scale) const {
- if ((k < 0) || (k >= size())) {
- TC_LOG(ERROR) << "Index outside [0, " << size() << "): " << k;
-
- // In debug mode, crash. In prod, pretend that k is 0.
- TC_DCHECK(false);
- k = 0;
- }
- *data = reinterpret_cast<const char *>(data_) + k * row_size_in_bytes_;
- if (quant_type_ == QuantizationType::NONE) {
- *scale = 1.0;
- } else {
- *scale = Float16To32(quant_scales_[k]);
- }
- }
-
- private:
- static int GetRowSizeInBytes(int cols, QuantizationType quant_type) {
- switch (quant_type) {
- case QuantizationType::NONE:
- return cols * sizeof(float);
- case QuantizationType::UINT8:
- return cols * sizeof(uint8);
- default:
- TC_LOG(ERROR) << "Unknown quant type: "
- << static_cast<int>(quant_type);
- return 0;
- }
- }
-
- // Vocabulary size.
- const int rows_;
-
- // Number of elements in each embedding.
- const int cols_;
-
- const QuantizationType quant_type_;
-
- // Pointer to the embedding weights, in row-major order. This is a pointer
- // to an array of floats / uint8, depending on the quantization type.
- // Not owned.
- const void *const data_;
-
- // Number of bytes for one row. Used to jump to next row in data_.
- const int row_size_in_bytes_;
-
- // Pointer to quantization scales. nullptr if no quantization. Otherwise,
- // quant_scales_[i] is scale for embedding of i-th vocabulary element.
- const float16 *const quant_scales_;
-
- TC_DISALLOW_COPY_AND_ASSIGN(EmbeddingMatrix);
- };
-
- // An immutable vector that doesn't own the memory that stores the underlying
- // floats. Can be used e.g., as a wrapper around model weights stored in the
- // static memory.
- class VectorWrapper {
- public:
- VectorWrapper() : VectorWrapper(nullptr, 0) {}
-
- // Constructs a vector wrapper around the size consecutive floats that start
- // at address data. Note: the underlying data should be alive for at least
- // the lifetime of this VectorWrapper object. That's trivially true if data
- // points to statically allocated data :)
- VectorWrapper(const float *data, int size) : data_(data), size_(size) {}
-
- int size() const { return size_; }
-
- const float *data() const { return data_; }
-
- private:
- const float *data_; // Not owned.
- int size_;
-
- // Doesn't own anything, so it can be copied and assigned at will :)
- };
-
- typedef std::vector<VectorWrapper> Matrix;
- typedef std::vector<float> Vector;
-
- // Constructs an embedding network using the parameters from model.
- //
- // Note: model should stay alive for at least the lifetime of this
- // EmbeddingNetwork object.
- explicit EmbeddingNetwork(const EmbeddingNetworkParams *model);
-
- virtual ~EmbeddingNetwork() {}
-
- // Returns true if this EmbeddingNetwork object has been correctly constructed
- // and is ready to use. Idea: in case of errors, mark this EmbeddingNetwork
- // object as invalid, but do not crash.
- bool is_valid() const { return valid_; }
-
- // Runs forward computation to fill scores with unnormalized output unit
- // scores. This is useful for making predictions.
- //
- // Returns true on success, false on error (e.g., if !is_valid()).
- bool ComputeFinalScores(const std::vector<FeatureVector> &features,
- Vector *scores) const;
-
- // Same as above, but allows specification of extra neural network inputs that
- // will be appended to the embedding vector build from features.
- bool ComputeFinalScores(const std::vector<FeatureVector> &features,
- const std::vector<float> extra_inputs,
- Vector *scores) const;
-
- // Constructs the concatenated input embedding vector in place in output
- // vector concat. Returns true on success, false on error.
- bool ConcatEmbeddings(const std::vector<FeatureVector> &features,
- Vector *concat) const;
-
- // Sums embeddings for all features from |feature_vector| and adds result
- // to values from the array pointed-to by |output|. Embeddings for continuous
- // features are weighted by the feature weight.
- //
- // NOTE: output should point to an array of EmbeddingSize(es_index) floats.
- bool GetEmbedding(const FeatureVector &feature_vector, int es_index,
- float *embedding) const;
-
- // Runs the feed-forward neural network for |input| and computes logits for
- // softmax layer.
- bool ComputeLogits(const Vector &input, Vector *scores) const;
-
- // Same as above but uses a view of the feature vector.
- bool ComputeLogits(const VectorSpan<float> &input, Vector *scores) const;
-
- // Returns the size (the number of columns) of the embedding space es_index.
- int EmbeddingSize(int es_index) const;
-
- protected:
- // Builds an embedding for given feature vector, and places it from
- // concat_offset to the concat vector.
- bool GetEmbeddingInternal(const FeatureVector &feature_vector,
- EmbeddingMatrix *embedding_matrix,
- int concat_offset, float *concat,
- int embedding_size) const;
-
- // Templated function that computes the logit scores given the concatenated
- // input embeddings.
- bool ComputeLogitsInternal(const VectorSpan<float> &concat,
- Vector *scores) const;
-
- // Computes the softmax scores (prior to normalization) from the concatenated
- // representation. Returns true on success, false on error.
- template <typename ScaleAdderClass>
- bool FinishComputeFinalScoresInternal(const VectorSpan<float> &concat,
- Vector *scores) const;
-
- // Set to true on successful construction, false otherwise.
- bool valid_ = false;
-
- // Network parameters.
-
- // One weight matrix for each embedding space.
- std::vector<std::unique_ptr<EmbeddingMatrix>> embedding_matrices_;
-
- // concat_offset_[i] is the input layer offset for i-th embedding space.
- std::vector<int> concat_offset_;
-
- // Size of the input ("concatenation") layer.
- int concat_layer_size_;
-
- // One weight matrix and one vector of bias weights for each hiden layer.
- std::vector<Matrix> hidden_weights_;
- std::vector<VectorWrapper> hidden_bias_;
-
- // Weight matrix and bias vector for the softmax layer.
- Matrix softmax_weights_;
- VectorWrapper softmax_bias_;
-};
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_
diff --git a/common/embedding-network.proto b/common/embedding-network.proto
deleted file mode 100644
index ce30b11..0000000
--- a/common/embedding-network.proto
+++ /dev/null
@@ -1,90 +0,0 @@
-// Copyright (C) 2017 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Protos for performing inference with an EmbeddingNetwork.
-
-syntax = "proto2";
-option optimize_for = LITE_RUNTIME;
-
-package libtextclassifier.nlp_core;
-
-// Wrapper for storing a matrix of parameters. These are stored in row-major
-// order.
-message MatrixParams {
- optional int32 rows = 1; // # of rows in the matrix
- optional int32 cols = 2; // # of columns in the matrix
-
- // Non-quantized matrix entries.
- repeated float value = 3 [packed = true];
-
- // Whether the matrix is quantized.
- optional bool is_quantized = 4 [default = false];
-
- // Bytes for all quantized values. Each value (see "repeated float value"
- // field) is quantized to an uint8 (1 byte) value, and all these bytes are
- // concatenated into the string from this field.
- optional bytes bytes_for_quantized_values = 7;
-
- // Bytes for all scale factors for dequantizing the values. The quantization
- // process generates a float16 scale factor for each column. The 2 bytes for
- // each such float16 are put in little-endian order (least significant byte
- // first) and next all these pairs of bytes are concatenated into the string
- // from this field.
- optional bytes bytes_for_col_scales = 8;
-
- reserved 5, 6;
-}
-
-// Stores all parameters for a given EmbeddingNetwork. This can either be a
-// EmbeddingNetwork or a PrecomputedEmbeddingNetwork: for precomputed networks,
-// the embedding weights are actually the activations of the first hidden layer
-// *before* the bias is added and the non-linear transform is applied.
-//
-// Thus, for PrecomputedEmbeddingNetwork storage, hidden layers are stored
-// starting from the second hidden layer, while biases are stored for every
-// hidden layer.
-message EmbeddingNetworkProto {
- // Embeddings and hidden layers. Note that if is_precomputed == true, then the
- // embeddings should store the activations of the first hidden layer, so we
- // must have hidden_bias_size() == hidden_size() + 1 (we store weights for
- // first hidden layer bias, but no the layer itself.)
- repeated MatrixParams embeddings = 1;
- repeated MatrixParams hidden = 2;
- repeated MatrixParams hidden_bias = 3;
-
- // Final layer of the network.
- optional MatrixParams softmax = 4;
- optional MatrixParams softmax_bias = 5;
-
- // Element i of the repeated field below indicates number of features that use
- // the i-th embedding space.
- repeated int32 embedding_num_features = 7;
-
- // Whether or not this is intended to store a precomputed network.
- optional bool is_precomputed = 11 [default = false];
-
- // True if this EmbeddingNetworkProto can be used for inference with no
- // additional matrix transposition.
- //
- // Given an EmbeddingNetworkProto produced by a Neurosis training pipeline, we
- // have to transpose a few matrices (e.g., the embedding matrices) before we
- // can perform inference. When we do so, we negate this flag. Note: we don't
- // simply set this to true: transposing twice takes us to the original state.
- optional bool is_transposed = 12 [default = false];
-
- // Allow extensions.
- extensions 100 to max;
-
- reserved 6, 8, 9, 10;
-}
diff --git a/common/embedding-network_test.cc b/common/embedding-network_test.cc
deleted file mode 100644
index 026ec17..0000000
--- a/common/embedding-network_test.cc
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common/embedding-network.h"
-#include "common/embedding-network-params-from-proto.h"
-#include "common/embedding-network.pb.h"
-#include "common/simple-adder.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace {
-
-using testing::ElementsAreArray;
-
-class TestingEmbeddingNetwork : public EmbeddingNetwork {
- public:
- using EmbeddingNetwork::EmbeddingNetwork;
- using EmbeddingNetwork::FinishComputeFinalScoresInternal;
-};
-
-void DiagonalAndBias3x3(int diagonal_value, int bias_value,
- MatrixParams* weights, MatrixParams* bias) {
- weights->set_rows(3);
- weights->set_cols(3);
- weights->add_value(diagonal_value);
- weights->add_value(0);
- weights->add_value(0);
- weights->add_value(0);
- weights->add_value(diagonal_value);
- weights->add_value(0);
- weights->add_value(0);
- weights->add_value(0);
- weights->add_value(diagonal_value);
-
- bias->set_rows(3);
- bias->set_cols(1);
- bias->add_value(bias_value);
- bias->add_value(bias_value);
- bias->add_value(bias_value);
-}
-
-TEST(EmbeddingNetworkTest, IdentityThroughMultipleLayers) {
- std::unique_ptr<EmbeddingNetworkProto> proto;
- proto.reset(new EmbeddingNetworkProto);
-
- // These layers should be an identity with bias.
- DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/1,
- proto->add_hidden(), proto->add_hidden_bias());
- DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/2,
- proto->add_hidden(), proto->add_hidden_bias());
- DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/3,
- proto->add_hidden(), proto->add_hidden_bias());
- DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/4,
- proto->add_hidden(), proto->add_hidden_bias());
- DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/5,
- proto->mutable_softmax(), proto->mutable_softmax_bias());
-
- EmbeddingNetworkParamsFromProto params(std::move(proto));
- TestingEmbeddingNetwork network(¶ms);
-
- std::vector<float> input({-2, -1, 0});
- std::vector<float> output;
- network.FinishComputeFinalScoresInternal<SimpleAdder>(
- VectorSpan<float>(input), &output);
-
- EXPECT_THAT(output, ElementsAreArray({14, 14, 15}));
-}
-
-} // namespace
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/common/feature-descriptors.h b/common/feature-descriptors.h
deleted file mode 100644
index 9aa6527..0000000
--- a/common/feature-descriptors.h
+++ /dev/null
@@ -1,154 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_FEATURE_DESCRIPTORS_H_
-#define LIBTEXTCLASSIFIER_COMMON_FEATURE_DESCRIPTORS_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "util/base/integral_types.h"
-#include "util/base/logging.h"
-#include "util/base/macros.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// Named feature parameter.
-class Parameter {
- public:
- Parameter() {}
-
- void set_name(const std::string &name) { name_ = name; }
- const std::string &name() const { return name_; }
-
- void set_value(const std::string &value) { value_ = value; }
- const std::string &value() const { return value_; }
-
- private:
- std::string name_;
- std::string value_;
-};
-
-// Descriptor for a feature function. Used to store the results of parsing one
-// feature function.
-class FeatureFunctionDescriptor {
- public:
- FeatureFunctionDescriptor() {}
-
- // Accessors for the feature function type. The function type is the string
- // that the feature extractor code is registered under.
- void set_type(const std::string &type) { type_ = type; }
- bool has_type() const { return !type_.empty(); }
- const std::string &type() const { return type_; }
-
- // Accessors for the feature function name. The function name (if available)
- // is used for some log messages. Otherwise, a more precise, but also more
- // verbose name based on the feature specification is used.
- void set_name(const std::string &name) { name_ = name; }
- bool has_name() const { return !name_.empty(); }
- const std::string &name() { return name_; }
-
- // Accessors for the default (name-less) parameter.
- void set_argument(int32 argument) { argument_ = argument; }
- bool has_argument() const {
- // If argument has not been specified, clients should treat it as 0. This
- // makes the test below correct, without having a separate has_argument_
- // bool field.
- return argument_ != 0;
- }
- int32 argument() const { return argument_; }
-
- // Accessors for the named parameters.
- Parameter *add_parameter() {
- parameters_.emplace_back();
- return &(parameters_.back());
- }
- int parameter_size() const { return parameters_.size(); }
- const Parameter ¶meter(int i) const {
- TC_DCHECK((i >= 0) && (i < parameter_size()));
- return parameters_[i];
- }
-
- // Accessors for the sub (i.e., nested) features. Nested features: as in
- // offset(1).label.
- FeatureFunctionDescriptor *add_feature() {
- sub_features_.emplace_back(new FeatureFunctionDescriptor());
- return sub_features_.back().get();
- }
- int feature_size() const { return sub_features_.size(); }
- const FeatureFunctionDescriptor &feature(int i) const {
- TC_DCHECK((i >= 0) && (i < feature_size()));
- return *(sub_features_[i].get());
- }
- FeatureFunctionDescriptor *mutable_feature(int i) {
- TC_DCHECK((i >= 0) && (i < feature_size()));
- return sub_features_[i].get();
- }
-
- private:
- // See comments for set_type().
- std::string type_;
-
- // See comments for set_name().
- std::string name_;
-
- // See comments for set_argument().
- int32 argument_ = 0;
-
- // See comemnts for add_parameter().
- std::vector<Parameter> parameters_;
-
- // See comments for add_feature().
- std::vector<std::unique_ptr<FeatureFunctionDescriptor>> sub_features_;
-
- TC_DISALLOW_COPY_AND_ASSIGN(FeatureFunctionDescriptor);
-};
-
-// List of FeatureFunctionDescriptors. Used to store the result of parsing the
-// spec for several feature functions.
-class FeatureExtractorDescriptor {
- public:
- FeatureExtractorDescriptor() {}
-
- int feature_size() const { return features_.size(); }
-
- FeatureFunctionDescriptor *add_feature() {
- features_.emplace_back(new FeatureFunctionDescriptor());
- return features_.back().get();
- }
-
- const FeatureFunctionDescriptor &feature(int i) const {
- TC_DCHECK((i >= 0) && (i < feature_size()));
- return *(features_[i].get());
- }
-
- FeatureFunctionDescriptor *mutable_feature(int i) {
- TC_DCHECK((i >= 0) && (i < feature_size()));
- return features_[i].get();
- }
-
- private:
- std::vector<std::unique_ptr<FeatureFunctionDescriptor>> features_;
-
- TC_DISALLOW_COPY_AND_ASSIGN(FeatureExtractorDescriptor);
-};
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_FEATURE_DESCRIPTORS_H_
diff --git a/common/feature-extractor.cc b/common/feature-extractor.cc
deleted file mode 100644
index 12de46d..0000000
--- a/common/feature-extractor.cc
+++ /dev/null
@@ -1,152 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common/feature-extractor.h"
-
-#include "common/feature-types.h"
-#include "common/fml-parser.h"
-#include "util/base/integral_types.h"
-#include "util/base/logging.h"
-#include "util/gtl/stl_util.h"
-#include "util/strings/numbers.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-constexpr FeatureValue GenericFeatureFunction::kNone;
-
-GenericFeatureExtractor::GenericFeatureExtractor() {}
-
-GenericFeatureExtractor::~GenericFeatureExtractor() {}
-
-bool GenericFeatureExtractor::Parse(const std::string &source) {
- // Parse feature specification into descriptor.
- FMLParser parser;
- if (!parser.Parse(source, mutable_descriptor())) return false;
-
- // Initialize feature extractor from descriptor.
- if (!InitializeFeatureFunctions()) return false;
- return true;
-}
-
-bool GenericFeatureExtractor::InitializeFeatureTypes() {
- // Register all feature types.
- GetFeatureTypes(&feature_types_);
- for (size_t i = 0; i < feature_types_.size(); ++i) {
- FeatureType *ft = feature_types_[i];
- ft->set_base(i);
-
- // Check for feature space overflow.
- double domain_size = ft->GetDomainSize();
- if (domain_size < 0) {
- TC_LOG(ERROR) << "Illegal domain size for feature " << ft->name() << ": "
- << domain_size;
- return false;
- }
- }
- return true;
-}
-
-FeatureValue GenericFeatureExtractor::GetDomainSize() const {
- // Domain size of the set of features is equal to:
- // [largest domain size of any feature types] * [number of feature types]
- FeatureValue max_feature_type_dsize = 0;
- for (size_t i = 0; i < feature_types_.size(); ++i) {
- FeatureType *ft = feature_types_[i];
- const FeatureValue feature_type_dsize = ft->GetDomainSize();
- if (feature_type_dsize > max_feature_type_dsize) {
- max_feature_type_dsize = feature_type_dsize;
- }
- }
-
- return max_feature_type_dsize * feature_types_.size();
-}
-
-std::string GenericFeatureFunction::GetParameter(
- const std::string &name) const {
- // Find named parameter in feature descriptor.
- for (int i = 0; i < descriptor_->parameter_size(); ++i) {
- if (name == descriptor_->parameter(i).name()) {
- return descriptor_->parameter(i).value();
- }
- }
- return "";
-}
-
-GenericFeatureFunction::GenericFeatureFunction() {}
-
-GenericFeatureFunction::~GenericFeatureFunction() { delete feature_type_; }
-
-int GenericFeatureFunction::GetIntParameter(const std::string &name,
- int default_value) const {
- int32 parsed_value = default_value;
- std::string value = GetParameter(name);
- if (!value.empty()) {
- if (!ParseInt32(value.c_str(), &parsed_value)) {
- // A parameter value has been specified, but it can't be parsed as an int.
- // We don't crash: instead, we long an error and return the default value.
- TC_LOG(ERROR) << "Value of param " << name << " is not an int: " << value;
- }
- }
- return parsed_value;
-}
-
-bool GenericFeatureFunction::GetBoolParameter(const std::string &name,
- bool default_value) const {
- std::string value = GetParameter(name);
- if (value.empty()) return default_value;
- if (value == "true") return true;
- if (value == "false") return false;
- TC_LOG(ERROR) << "Illegal value '" << value << "' for bool parameter '"
- << name << "'"
- << " will assume default " << default_value;
- return default_value;
-}
-
-void GenericFeatureFunction::GetFeatureTypes(
- std::vector<FeatureType *> *types) const {
- if (feature_type_ != nullptr) types->push_back(feature_type_);
-}
-
-FeatureType *GenericFeatureFunction::GetFeatureType() const {
- // If a single feature type has been registered return it.
- if (feature_type_ != nullptr) return feature_type_;
-
- // Get feature types for function.
- std::vector<FeatureType *> types;
- GetFeatureTypes(&types);
-
- // If there is exactly one feature type return this, else return null.
- if (types.size() == 1) return types[0];
- return nullptr;
-}
-
-std::string GenericFeatureFunction::name() const {
- std::string output;
- if (descriptor_->name().empty()) {
- if (!prefix_.empty()) {
- output.append(prefix_);
- output.append(".");
- }
- ToFML(*descriptor_, &output);
- } else {
- output = descriptor_->name();
- }
- return output;
-}
-
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/common/feature-extractor.h b/common/feature-extractor.h
deleted file mode 100644
index bdba609..0000000
--- a/common/feature-extractor.h
+++ /dev/null
@@ -1,665 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Generic feature extractor for extracting features from objects. The feature
-// extractor can be used for extracting features from any object. The feature
-// extractor and feature function classes are template classes that have to
-// be instantiated for extracting feature from a specific object type.
-//
-// A feature extractor consists of a hierarchy of feature functions. Each
-// feature function extracts one or more feature type and value pairs from the
-// object.
-//
-// The feature extractor has a modular design where new feature functions can be
-// registered as components. The feature extractor is initialized from a
-// descriptor represented by a protocol buffer. The feature extractor can also
-// be initialized from a text-based source specification of the feature
-// extractor. Feature specification parsers can be added as components. By
-// default the feature extractor can be read from an ASCII protocol buffer or in
-// a simple feature modeling language (fml).
-
-// A feature function is invoked with a focus. Nested feature function can be
-// invoked with another focus determined by the parent feature function.
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_FEATURE_EXTRACTOR_H_
-#define LIBTEXTCLASSIFIER_COMMON_FEATURE_EXTRACTOR_H_
-
-#include <stddef.h>
-
-#include <string>
-#include <vector>
-
-#include "common/feature-descriptors.h"
-#include "common/feature-types.h"
-#include "common/fml-parser.h"
-#include "common/registry.h"
-#include "common/task-context.h"
-#include "common/workspace.h"
-#include "util/base/integral_types.h"
-#include "util/base/logging.h"
-#include "util/base/macros.h"
-#include "util/gtl/stl_util.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-typedef int64 Predicate;
-typedef Predicate FeatureValue;
-
-// A union used to represent discrete and continuous feature values.
-union FloatFeatureValue {
- public:
- explicit FloatFeatureValue(FeatureValue v) : discrete_value(v) {}
- FloatFeatureValue(uint32 i, float w) : id(i), weight(w) {}
- FeatureValue discrete_value;
- struct {
- uint32 id;
- float weight;
- };
-};
-
-// A feature vector contains feature type and value pairs.
-class FeatureVector {
- public:
- FeatureVector() {}
-
- // Adds feature type and value pair to feature vector.
- void add(FeatureType *type, FeatureValue value) {
- features_.emplace_back(type, value);
- }
-
- // Removes all elements from the feature vector.
- void clear() { features_.clear(); }
-
- // Returns the number of elements in the feature vector.
- int size() const { return features_.size(); }
-
- // Reserves space in the underlying feature vector.
- void reserve(int n) { features_.reserve(n); }
-
- // Returns feature type for an element in the feature vector.
- FeatureType *type(int index) const { return features_[index].type; }
-
- // Returns feature value for an element in the feature vector.
- FeatureValue value(int index) const { return features_[index].value; }
-
- private:
- // Structure for holding feature type and value pairs.
- struct Element {
- Element() : type(nullptr), value(-1) {}
- Element(FeatureType *t, FeatureValue v) : type(t), value(v) {}
-
- FeatureType *type;
- FeatureValue value;
- };
-
- // Array for storing feature vector elements.
- std::vector<Element> features_;
-
- TC_DISALLOW_COPY_AND_ASSIGN(FeatureVector);
-};
-
-// The generic feature extractor is the type-independent part of a feature
-// extractor. This holds the descriptor for the feature extractor and the
-// collection of feature types used in the feature extractor. The feature
-// types are not available until FeatureExtractor<>::Init() has been called.
-class GenericFeatureExtractor {
- public:
- GenericFeatureExtractor();
- virtual ~GenericFeatureExtractor();
-
- // Initializes the feature extractor from an FML string specification. For
- // the FML specification grammar, see fml-parser.h.
- //
- // Returns true on success, false on syntax error.
- bool Parse(const std::string &source);
-
- // Returns the feature extractor descriptor.
- const FeatureExtractorDescriptor &descriptor() const { return descriptor_; }
- FeatureExtractorDescriptor *mutable_descriptor() { return &descriptor_; }
-
- // Returns the number of feature types in the feature extractor. Invalid
- // before Init() has been called.
- int feature_types() const { return feature_types_.size(); }
-
- // Returns a feature type used in the extractor. Invalid before Init() has
- // been called.
- const FeatureType *feature_type(int index) const {
- return feature_types_[index];
- }
-
- // Returns the feature domain size of this feature extractor.
- // NOTE: The way that domain size is calculated is, for some, unintuitive. It
- // is the largest domain size of any feature type.
- FeatureValue GetDomainSize() const;
-
- protected:
- // Initializes the feature types used by the extractor. Called from
- // FeatureExtractor<>::Init().
- //
- // Returns true on success, false on error.
- bool InitializeFeatureTypes();
-
- private:
- // Initializes the top-level feature functions.
- virtual bool InitializeFeatureFunctions() = 0;
-
- // Returns all feature types used by the extractor. The feature types are
- // added to the result array.
- virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const = 0;
-
- // Descriptor for the feature extractor. This is a protocol buffer that
- // contains all the information about the feature extractor. The feature
- // functions are initialized from the information in the descriptor.
- FeatureExtractorDescriptor descriptor_;
-
- // All feature types used by the feature extractor. The collection of all the
- // feature types describes the feature space of the feature set produced by
- // the feature extractor. Not owned.
- std::vector<FeatureType *> feature_types_;
-
- TC_DISALLOW_COPY_AND_ASSIGN(GenericFeatureExtractor);
-};
-
-// The generic feature function is the type-independent part of a feature
-// function. Each feature function is associated with the descriptor that it is
-// instantiated from. The feature types associated with this feature function
-// will be established by the time FeatureExtractor<>::Init() completes.
-class GenericFeatureFunction {
- public:
- // A feature value that represents the absence of a value.
- static constexpr FeatureValue kNone = -1;
-
- GenericFeatureFunction();
- virtual ~GenericFeatureFunction();
-
- // Sets up the feature function. NB: FeatureTypes of nested functions are not
- // guaranteed to be available until Init().
- //
- // Returns true on success, false on error.
- virtual bool Setup(TaskContext *context) { return true; }
-
- // Initializes the feature function. NB: The FeatureType of this function must
- // be established when this method completes.
- //
- // Returns true on success, false on error.
- virtual bool Init(TaskContext *context) { return true; }
-
- // Requests workspaces from a registry to obtain indices into a WorkspaceSet
- // for any Workspace objects used by this feature function. NB: This will be
- // called after Init(), so it can depend on resources and arguments.
- virtual void RequestWorkspaces(WorkspaceRegistry *registry) {}
-
- // Appends the feature types produced by the feature function to types. The
- // default implementation appends feature_type(), if non-null. Invalid
- // before Init() has been called.
- virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const;
-
- // Returns the feature type for feature produced by this feature function. If
- // the feature function produces features of different types this returns
- // null. Invalid before Init() has been called.
- virtual FeatureType *GetFeatureType() const;
-
- // Returns the name of the registry used for creating the feature function.
- // This can be used for checking if two feature functions are of the same
- // kind.
- virtual const char *RegistryName() const = 0;
-
- // Returns the value of a named parameter from the feature function
- // descriptor. Returns empty string ("") if parameter is not found.
- std::string GetParameter(const std::string &name) const;
-
- // Returns the int value of a named parameter from the feature function
- // descriptor. Returns default_value if the parameter is not found or if its
- // value can't be parsed as an int.
- int GetIntParameter(const std::string &name, int default_value) const;
-
- // Returns the bool value of a named parameter from the feature function
- // descriptor. Returns default_value if the parameter is not found or if its
- // value is not "true" or "false".
- bool GetBoolParameter(const std::string &name, bool default_value) const;
-
- // Returns the FML function description for the feature function, i.e. the
- // name and parameters without the nested features.
- std::string FunctionName() const {
- std::string output;
- ToFMLFunction(*descriptor_, &output);
- return output;
- }
-
- // Returns the prefix for nested feature functions. This is the prefix of this
- // feature function concatenated with the feature function name.
- std::string SubPrefix() const {
- return prefix_.empty() ? FunctionName() : prefix_ + "." + FunctionName();
- }
-
- // Returns/sets the feature extractor this function belongs to.
- GenericFeatureExtractor *extractor() const { return extractor_; }
- void set_extractor(GenericFeatureExtractor *extractor) {
- extractor_ = extractor;
- }
-
- // Returns/sets the feature function descriptor.
- FeatureFunctionDescriptor *descriptor() const { return descriptor_; }
- void set_descriptor(FeatureFunctionDescriptor *descriptor) {
- descriptor_ = descriptor;
- }
-
- // Returns a descriptive name for the feature function. The name is taken from
- // the descriptor for the feature function. If the name is empty or the
- // feature function is a variable the name is the FML representation of the
- // feature, including the prefix.
- std::string name() const;
-
- // Returns the argument from the feature function descriptor. It defaults to
- // 0 if the argument has not been specified.
- int argument() const {
- return descriptor_->has_argument() ? descriptor_->argument() : 0;
- }
-
- // Returns/sets/clears function name prefix.
- const std::string &prefix() const { return prefix_; }
- void set_prefix(const std::string &prefix) { prefix_ = prefix; }
-
- protected:
- // Returns the feature type for single-type feature functions.
- FeatureType *feature_type() const { return feature_type_; }
-
- // Sets the feature type for single-type feature functions. This takes
- // ownership of feature_type. Can only be called once with a non-null
- // pointer.
- void set_feature_type(FeatureType *feature_type) {
- TC_DCHECK_NE(feature_type, nullptr);
- feature_type_ = feature_type;
- }
-
- private:
- // Feature extractor this feature function belongs to. Not owned.
- GenericFeatureExtractor *extractor_ = nullptr;
-
- // Descriptor for feature function. Not owned.
- FeatureFunctionDescriptor *descriptor_ = nullptr;
-
- // Feature type for features produced by this feature function. If the
- // feature function produces features of multiple feature types this is null
- // and the feature function must return it's feature types in
- // GetFeatureTypes(). Owned.
- FeatureType *feature_type_ = nullptr;
-
- // Prefix used for sub-feature types of this function.
- std::string prefix_;
-};
-
-// Feature function that can extract features from an object. Templated on
-// two type arguments:
-//
-// OBJ: The "object" from which features are extracted; e.g., a sentence. This
-// should be a plain type, rather than a reference or pointer.
-//
-// ARGS: A set of 0 or more types that are used to "index" into some part of the
-// object that should be extracted, e.g. an int token index for a sentence
-// object. This should not be a reference type.
-template <class OBJ, class... ARGS>
-class FeatureFunction
- : public GenericFeatureFunction,
- public RegisterableClass<FeatureFunction<OBJ, ARGS...> > {
- public:
- using Self = FeatureFunction<OBJ, ARGS...>;
-
- // Preprocesses the object. This will be called prior to calling Evaluate()
- // or Compute() on that object.
- virtual void Preprocess(WorkspaceSet *workspaces, OBJ *object) const {}
-
- // Appends features computed from the object and focus to the result. The
- // default implementation delegates to Compute(), adding a single value if
- // available. Multi-valued feature functions must override this method.
- virtual void Evaluate(const WorkspaceSet &workspaces, const OBJ &object,
- ARGS... args, FeatureVector *result) const {
- FeatureValue value = Compute(workspaces, object, args..., result);
- if (value != kNone) result->add(feature_type(), value);
- }
-
- // Returns a feature value computed from the object and focus, or kNone if no
- // value is computed. Single-valued feature functions only need to override
- // this method.
- virtual FeatureValue Compute(const WorkspaceSet &workspaces,
- const OBJ &object, ARGS... args,
- const FeatureVector *fv) const {
- return kNone;
- }
-
- // Instantiates a new feature function in a feature extractor from a feature
- // descriptor.
- static Self *Instantiate(GenericFeatureExtractor *extractor,
- FeatureFunctionDescriptor *fd,
- const std::string &prefix) {
- Self *f = Self::Create(fd->type());
- if (f != nullptr) {
- f->set_extractor(extractor);
- f->set_descriptor(fd);
- f->set_prefix(prefix);
- }
- return f;
- }
-
- // Returns the name of the registry for the feature function.
- const char *RegistryName() const override { return Self::registry()->name(); }
-
- private:
- // Special feature function class for resolving variable references. The type
- // of the feature function is used for resolving the variable reference. When
- // evaluated it will either get the feature value(s) from the variable portion
- // of the feature vector, if present, or otherwise it will call the referenced
- // feature extractor function directly to extract the feature(s).
- class Reference;
-};
-
-// Base class for features with nested feature functions. The nested functions
-// are of type NES, which may be different from the type of the parent function.
-// NB: NestedFeatureFunction will ensure that all initialization of nested
-// functions takes place during Setup() and Init() -- after the nested features
-// are initialized, the parent feature is initialized via SetupNested() and
-// InitNested(). Alternatively, a derived classes that overrides Setup() and
-// Init() directly should call Parent::Setup(), Parent::Init(), etc. first.
-//
-// Note: NestedFeatureFunction cannot know how to call Preprocess, Evaluate, or
-// Compute, since the nested functions may be of a different type.
-template <class NES, class OBJ, class... ARGS>
-class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> {
- public:
- using Parent = NestedFeatureFunction<NES, OBJ, ARGS...>;
-
- // Clean up nested functions.
- ~NestedFeatureFunction() override {
- // Fully qualified class name, to avoid an ambiguity error when building for
- // Android.
- ::libtextclassifier::STLDeleteElements(&nested_);
- }
-
- // By default, just appends the nested feature types.
- void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
- // It's odd if a NestedFeatureFunction does not have anything nested inside
- // it, so we crash in debug mode. Still, nothing should crash in prod mode.
- TC_DCHECK(!this->nested().empty())
- << "Nested features require nested features to be defined.";
- for (auto *function : nested_) function->GetFeatureTypes(types);
- }
-
- // Sets up the nested features.
- bool Setup(TaskContext *context) override {
- bool success = CreateNested(this->extractor(), this->descriptor(), &nested_,
- this->SubPrefix());
- if (!success) {
- return false;
- }
- for (auto *function : nested_) {
- if (!function->Setup(context)) return false;
- }
- if (!SetupNested(context)) {
- return false;
- }
- return true;
- }
-
- // Sets up this NestedFeatureFunction specifically.
- virtual bool SetupNested(TaskContext *context) { return true; }
-
- // Initializes the nested features.
- bool Init(TaskContext *context) override {
- for (auto *function : nested_) {
- if (!function->Init(context)) return false;
- }
- if (!InitNested(context)) return false;
- return true;
- }
-
- // Initializes this NestedFeatureFunction specifically.
- virtual bool InitNested(TaskContext *context) { return true; }
-
- // Gets all the workspaces needed for the nested functions.
- void RequestWorkspaces(WorkspaceRegistry *registry) override {
- for (auto *function : nested_) function->RequestWorkspaces(registry);
- }
-
- // Returns the list of nested feature functions.
- const std::vector<NES *> &nested() const { return nested_; }
-
- // Instantiates nested feature functions for a feature function. Creates and
- // initializes one feature function for each sub-descriptor in the feature
- // descriptor.
- static bool CreateNested(GenericFeatureExtractor *extractor,
- FeatureFunctionDescriptor *fd,
- std::vector<NES *> *functions,
- const std::string &prefix) {
- for (int i = 0; i < fd->feature_size(); ++i) {
- FeatureFunctionDescriptor *sub = fd->mutable_feature(i);
- NES *f = NES::Instantiate(extractor, sub, prefix);
- if (f == nullptr) {
- return false;
- }
- functions->push_back(f);
- }
- return true;
- }
-
- protected:
- // The nested feature functions, if any, in order of declaration in the
- // feature descriptor. Owned.
- std::vector<NES *> nested_;
-};
-
-// Base class for a nested feature function that takes nested features with the
-// same signature as these features, i.e. a meta feature. For this class, we can
-// provide preprocessing of the nested features.
-template <class OBJ, class... ARGS>
-class MetaFeatureFunction
- : public NestedFeatureFunction<FeatureFunction<OBJ, ARGS...>, OBJ,
- ARGS...> {
- public:
- // Preprocesses using the nested features.
- void Preprocess(WorkspaceSet *workspaces, OBJ *object) const override {
- for (auto *function : this->nested_) {
- function->Preprocess(workspaces, object);
- }
- }
-};
-
-// Template for a special type of locator: The locator of type
-// FeatureFunction<OBJ, ARGS...> calls nested functions of type
-// FeatureFunction<OBJ, IDX, ARGS...>, where the derived class DER is
-// responsible for translating by providing the following:
-//
-// // Gets the new additional focus.
-// IDX GetFocus(const WorkspaceSet &workspaces, const OBJ &object);
-//
-// This is useful to e.g. add a token focus to a parser state based on some
-// desired property of that state.
-template <class DER, class OBJ, class IDX, class... ARGS>
-class FeatureAddFocusLocator
- : public NestedFeatureFunction<FeatureFunction<OBJ, IDX, ARGS...>, OBJ,
- ARGS...> {
- public:
- void Preprocess(WorkspaceSet *workspaces, OBJ *object) const override {
- for (auto *function : this->nested_) {
- function->Preprocess(workspaces, object);
- }
- }
-
- void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args,
- FeatureVector *result) const override {
- IDX focus =
- static_cast<const DER *>(this)->GetFocus(workspaces, object, args...);
- for (auto *function : this->nested()) {
- function->Evaluate(workspaces, object, focus, args..., result);
- }
- }
-
- // Returns the first nested feature's computed value.
- FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object,
- ARGS... args,
- const FeatureVector *result) const override {
- IDX focus =
- static_cast<const DER *>(this)->GetFocus(workspaces, object, args...);
- return this->nested()[0]->Compute(workspaces, object, focus, args...,
- result);
- }
-};
-
-// CRTP feature locator class. This is a meta feature that modifies ARGS and
-// then calls the nested feature functions with the modified ARGS. Note that in
-// order for this template to work correctly, all of ARGS must be types for
-// which the reference operator & can be interpreted as a pointer to the
-// argument. The derived class DER must implement the UpdateFocus method which
-// takes pointers to the ARGS arguments:
-//
-// // Updates the current arguments.
-// void UpdateArgs(const OBJ &object, ARGS *...args) const;
-template <class DER, class OBJ, class... ARGS>
-class FeatureLocator : public MetaFeatureFunction<OBJ, ARGS...> {
- public:
- // Feature locators have an additional check that there is no intrinsic type,
- // but only in debug mode: having an intrinsic type here is odd, but not
- // enough to motive a crash in prod.
- void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
- TC_DCHECK_EQ(this->feature_type(), nullptr)
- << "FeatureLocators should not have an intrinsic type.";
- MetaFeatureFunction<OBJ, ARGS...>::GetFeatureTypes(types);
- }
-
- // Evaluates the locator.
- void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args,
- FeatureVector *result) const override {
- static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
- for (auto *function : this->nested()) {
- function->Evaluate(workspaces, object, args..., result);
- }
- }
-
- // Returns the first nested feature's computed value.
- FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object,
- ARGS... args,
- const FeatureVector *result) const override {
- static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
- return this->nested()[0]->Compute(workspaces, object, args..., result);
- }
-};
-
-// Feature extractor for extracting features from objects of a certain class.
-// Template type parameters are as defined for FeatureFunction.
-template <class OBJ, class... ARGS>
-class FeatureExtractor : public GenericFeatureExtractor {
- public:
- // Feature function type for top-level functions in the feature extractor.
- typedef FeatureFunction<OBJ, ARGS...> Function;
- typedef FeatureExtractor<OBJ, ARGS...> Self;
-
- // Feature locator type for the feature extractor.
- template <class DER>
- using Locator = FeatureLocator<DER, OBJ, ARGS...>;
-
- // Initializes feature extractor.
- FeatureExtractor() {}
-
- ~FeatureExtractor() override {
- // Fully qualified class name, to avoid an ambiguity error when building for
- // Android.
- ::libtextclassifier::STLDeleteElements(&functions_);
- }
-
- // Sets up the feature extractor. Note that only top-level functions exist
- // until Setup() is called. This does not take ownership over the context,
- // which must outlive this.
- bool Setup(TaskContext *context) {
- for (Function *function : functions_) {
- if (!function->Setup(context)) return false;
- }
- return true;
- }
-
- // Initializes the feature extractor. Must be called after Setup(). This
- // does not take ownership over the context, which must outlive this.
- bool Init(TaskContext *context) {
- for (Function *function : functions_) {
- if (!function->Init(context)) return false;
- }
- if (!this->InitializeFeatureTypes()) {
- return false;
- }
- return true;
- }
-
- // Requests workspaces from the registry. Must be called after Init(), and
- // before Preprocess(). Does not take ownership over registry. This should be
- // the same registry used to initialize the WorkspaceSet used in Preprocess()
- // and ExtractFeatures(). NB: This is a different ordering from that used in
- // SentenceFeatureRepresentation style feature computation.
- void RequestWorkspaces(WorkspaceRegistry *registry) {
- for (auto *function : functions_) function->RequestWorkspaces(registry);
- }
-
- // Preprocesses the object using feature functions for the phase. Must be
- // called before any calls to ExtractFeatures() on that object and phase.
- void Preprocess(WorkspaceSet *workspaces, OBJ *object) const {
- for (Function *function : functions_) {
- function->Preprocess(workspaces, object);
- }
- }
-
- // Extracts features from an object with a focus. This invokes all the
- // top-level feature functions in the feature extractor. Only feature
- // functions belonging to the specified phase are invoked.
- void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &object,
- ARGS... args, FeatureVector *result) const {
- result->reserve(this->feature_types());
-
- // Extract features.
- for (int i = 0; i < functions_.size(); ++i) {
- functions_[i]->Evaluate(workspaces, object, args..., result);
- }
- }
-
- private:
- // Creates and initializes all feature functions in the feature extractor.
- bool InitializeFeatureFunctions() override {
- // Create all top-level feature functions.
- for (int i = 0; i < descriptor().feature_size(); ++i) {
- FeatureFunctionDescriptor *fd = mutable_descriptor()->mutable_feature(i);
- Function *function = Function::Instantiate(this, fd, "");
- if (function == nullptr) return false;
- functions_.push_back(function);
- }
- return true;
- }
-
- // Collect all feature types used in the feature extractor.
- void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
- for (Function *function : functions_) {
- function->GetFeatureTypes(types);
- }
- }
-
- // Top-level feature functions (and variables) in the feature extractor.
- // Owned. INVARIANT: contains only non-null pointers.
- std::vector<Function *> functions_;
-};
-
-#define REGISTER_FEATURE_FUNCTION(base, name, component) \
- REGISTER_CLASS_COMPONENT(base, name, component)
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_FEATURE_EXTRACTOR_H_
diff --git a/common/feature-types.h b/common/feature-types.h
deleted file mode 100644
index 92814d9..0000000
--- a/common/feature-types.h
+++ /dev/null
@@ -1,189 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Common feature types for parser components.
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_FEATURE_TYPES_H_
-#define LIBTEXTCLASSIFIER_COMMON_FEATURE_TYPES_H_
-
-#include <algorithm>
-#include <map>
-#include <string>
-#include <utility>
-
-#include "util/base/integral_types.h"
-#include "util/base/logging.h"
-#include "util/strings/numbers.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// TODO(djweiss) Clean this up as well.
-// Use the same type for feature values as is used for predicated.
-typedef int64 Predicate;
-typedef Predicate FeatureValue;
-
-// Each feature value in a feature vector has a feature type. The feature type
-// is used for converting feature type and value pairs to predicate values. The
-// feature type can also return names for feature values and calculate the size
-// of the feature value domain. The FeatureType class is abstract and must be
-// specialized for the concrete feature types.
-class FeatureType {
- public:
- // Initializes a feature type.
- explicit FeatureType(const std::string &name)
- : name_(name), base_(0),
- is_continuous_(name.find("continuous") != std::string::npos) {
- }
-
- virtual ~FeatureType() {}
-
- // Converts a feature value to a name.
- virtual std::string GetFeatureValueName(FeatureValue value) const = 0;
-
- // Returns the size of the feature values domain.
- virtual int64 GetDomainSize() const = 0;
-
- // Returns the feature type name.
- const std::string &name() const { return name_; }
-
- Predicate base() const { return base_; }
- void set_base(Predicate base) { base_ = base; }
-
- // Returns true iff this feature is continuous; see FloatFeatureValue.
- bool is_continuous() const { return is_continuous_; }
-
- private:
- // Feature type name.
- std::string name_;
-
- // "Base" feature value: i.e. a "slot" in a global ordering of features.
- Predicate base_;
-
- // See doc for is_continuous().
- bool is_continuous_;
-};
-
-// Feature type that is defined using an explicit map from FeatureValue to
-// std::string values. This can reduce some of the boilerplate when defining
-// features that generate enum values. Example usage:
-//
-// class BeverageSizeFeature : public FeatureFunction<Beverage>
-// enum FeatureValue { SMALL, MEDIUM, LARGE }; // values for this feature
-// void Init(TaskContext *context) override {
-// set_feature_type(new EnumFeatureType("beverage_size",
-// {{SMALL, "SMALL"}, {MEDIUM, "MEDIUM"}, {LARGE, "LARGE"}});
-// }
-// [...]
-// };
-class EnumFeatureType : public FeatureType {
- public:
- EnumFeatureType(const std::string &name,
- const std::map<FeatureValue, std::string> &value_names)
- : FeatureType(name), value_names_(value_names) {
- for (const auto &pair : value_names) {
- TC_DCHECK_GE(pair.first, 0)
- << "Invalid feature value: " << pair.first << ", " << pair.second;
- domain_size_ = std::max(domain_size_, pair.first + 1);
- }
- }
-
- // Returns the feature name for a given feature value.
- std::string GetFeatureValueName(FeatureValue value) const override {
- auto it = value_names_.find(value);
- if (it == value_names_.end()) {
- TC_LOG(ERROR) << "Invalid feature value " << value << " for " << name();
- return "<INVALID>";
- }
- return it->second;
- }
-
- // Returns the number of possible values for this feature type. This is one
- // greater than the largest value in the value_names map.
- FeatureValue GetDomainSize() const override { return domain_size_; }
-
- protected:
- // Maximum possible value this feature could take.
- FeatureValue domain_size_ = 0;
-
- // Names of feature values.
- std::map<FeatureValue, std::string> value_names_;
-};
-
-// Feature type for binary features.
-class BinaryFeatureType : public FeatureType {
- public:
- BinaryFeatureType(const std::string &name, const std::string &off,
- const std::string &on)
- : FeatureType(name), off_(off), on_(on) {}
-
- // Returns the feature name for a given feature value.
- std::string GetFeatureValueName(FeatureValue value) const override {
- if (value == 0) return off_;
- if (value == 1) return on_;
- return "";
- }
-
- // Binary features always have two feature values.
- FeatureValue GetDomainSize() const override { return 2; }
-
- private:
- // Feature value names for on and off.
- std::string off_;
- std::string on_;
-};
-
-// Feature type for numeric features.
-class NumericFeatureType : public FeatureType {
- public:
- // Initializes numeric feature.
- NumericFeatureType(const std::string &name, FeatureValue size)
- : FeatureType(name), size_(size) {}
-
- // Returns numeric feature value.
- std::string GetFeatureValueName(FeatureValue value) const override {
- if (value < 0) return "";
- return IntToString(value);
- }
-
- // Returns the number of feature values.
- FeatureValue GetDomainSize() const override { return size_; }
-
- private:
- // The underlying size of the numeric feature.
- FeatureValue size_;
-};
-
-// Feature type for byte features, including an "outside" value.
-class ByteFeatureType : public NumericFeatureType {
- public:
- explicit ByteFeatureType(const std::string &name)
- : NumericFeatureType(name, 257) {}
-
- std::string GetFeatureValueName(FeatureValue value) const override {
- if (value == 256) {
- return "<NULL>";
- }
- std::string result;
- result += static_cast<char>(value);
- return result;
- }
-};
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_FEATURE_TYPES_H_
diff --git a/common/file-utils.cc b/common/file-utils.cc
deleted file mode 100644
index 6ae4442..0000000
--- a/common/file-utils.cc
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common/file-utils.h"
-
-#include <fcntl.h>
-#include <stdio.h>
-#include <sys/stat.h>
-#include <sys/types.h>
-
-#include <fstream>
-#include <memory>
-#include <string>
-
-#include "util/base/logging.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-namespace file_utils {
-
-bool GetFileContent(const std::string &filename, std::string *content) {
- std::ifstream input_stream(filename, std::ifstream::binary);
- if (input_stream.fail()) {
- TC_LOG(INFO) << "Error opening " << filename;
- return false;
- }
-
- content->assign(
- std::istreambuf_iterator<char>(input_stream),
- std::istreambuf_iterator<char>());
-
- if (input_stream.fail()) {
- TC_LOG(ERROR) << "Error reading " << filename;
- return false;
- }
-
- TC_LOG(INFO) << "Successfully read " << filename;
- return true;
-}
-
-bool FileExists(const std::string &filename) {
- struct stat s = {0};
- if (!stat(filename.c_str(), &s)) {
- return s.st_mode & S_IFREG;
- } else {
- return false;
- }
-}
-
-bool DirectoryExists(const std::string &dirpath) {
- struct stat s = {0};
- if (!stat(dirpath.c_str(), &s)) {
- return s.st_mode & S_IFDIR;
- } else {
- return false;
- }
-}
-
-} // namespace file_utils
-
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/common/file-utils.h b/common/file-utils.h
deleted file mode 100644
index e2a60f2..0000000
--- a/common/file-utils.h
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_FILE_UTILS_H_
-#define LIBTEXTCLASSIFIER_COMMON_FILE_UTILS_H_
-
-#include <cstddef>
-#include <memory>
-#include <string>
-
-#include "common/config.h"
-
-#if PORTABLE_SAFT_MOBILE
-#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
-#endif
-
-#include "common/mmap.h"
-#include "util/strings/stringpiece.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-namespace file_utils {
-
-// Reads the entire content of a file into a string. Returns true on success,
-// false on error.
-bool GetFileContent(const std::string &filename, std::string *content);
-
-// Parses a proto from its serialized representation in memory. That
-// representation starts at address sp.data() and contains exactly sp.size()
-// bytes. Returns true on success, false otherwise.
-template<class Proto>
-bool ParseProtoFromMemory(StringPiece sp, Proto *proto) {
- if (!sp.data()) {
- // Avoid passing a nullptr to ArrayInputStream below.
- return false;
- }
-#if PORTABLE_SAFT_MOBILE
- ::google::protobuf::io::ArrayInputStream stream(sp.data(), sp.size());
- return proto->ParseFromZeroCopyStream(&stream);
-#else
-
- std::string data(sp.data(), sp.size());
- return proto->ParseFromString(data);
-#endif
-}
-
-// Parses a proto from a file. Returns true on success, false otherwise.
-//
-// Note: the entire content of the file should be the binary (not
-// human-readable) serialization of a protocol buffer.
-//
-// Note: when we compile for Android, the proto parsing methods need to know the
-// type of the message they are parsing. We use template polymorphism for that.
-template<class Proto>
-bool ReadProtoFromFile(const std::string &filename, Proto *proto) {
- ScopedMmap scoped_mmap(filename);
- const MmapHandle &handle = scoped_mmap.handle();
- if (!handle.ok()) {
- return false;
- }
- return ParseProtoFromMemory(handle.to_stringpiece(), proto);
-}
-
-// Returns true if filename is the name of an existing file, and false
-// otherwise.
-bool FileExists(const std::string &filename);
-
-// Returns true if dirpath is the path to an existing directory, and false
-// otherwise.
-bool DirectoryExists(const std::string &dirpath);
-
-} // namespace file_utils
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_FILE_UTILS_H_
diff --git a/common/float16.h b/common/float16.h
deleted file mode 100644
index 8b52be3..0000000
--- a/common/float16.h
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_FLOAT16_H_
-#define LIBTEXTCLASSIFIER_COMMON_FLOAT16_H_
-
-#include "util/base/casts.h"
-#include "util/base/integral_types.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// 16 bit encoding of a float. NOTE: can't be used directly for computation:
-// one first needs to convert it to a normal float, using Float16To32.
-//
-// Documentation copied from original file:
-//
-// Compact 16-bit encoding of floating point numbers. This
-// representation uses 1 bit for the sign, 8 bits for the exponent and
-// 7 bits for the mantissa. It is assumed that floats are in IEEE 754
-// format so a float16 is just bits 16-31 of a single precision float.
-//
-// NOTE: The IEEE floating point standard defines a float16 format that
-// is different than this format (it has fewer bits of exponent and more
-// bits of mantissa). We don't use that format here because conversion
-// to/from 32-bit floats is more complex for that format, and the
-// conversion for this format is very simple.
-//
-// <---------float16------------>
-// s e e e e e e e e f f f f f f f f f f f f f f f f f f f f f f f
-// <------------------------------float-------------------------->
-// 3 3 2 2 1 1 0
-// 1 0 3 2 5 4 0
-
-typedef uint16 float16;
-
-static inline float16 Float32To16(float f) {
- // Note that we just truncate the mantissa bits: we make no effort to
- // do any smarter rounding.
- return (bit_cast<uint32>(f) >> 16) & 0xffff;
-}
-
-static inline float Float16To32(float16 f) {
- // We fill in the new mantissa bits with 0, and don't do anything smarter.
- return bit_cast<float>(f << 16);
-}
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_FLOAT16_H_
diff --git a/common/fml-parser.cc b/common/fml-parser.cc
deleted file mode 100644
index 2964671..0000000
--- a/common/fml-parser.cc
+++ /dev/null
@@ -1,329 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common/fml-parser.h"
-
-#include <ctype.h>
-#include <string>
-
-#include "util/base/logging.h"
-#include "util/strings/numbers.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-namespace {
-inline bool IsValidCharAtStartOfIdentifier(char c) {
- return isalpha(c) || (c == '_') || (c == '/');
-}
-
-// Returns true iff character c can appear inside an identifier.
-inline bool IsValidCharInsideIdentifier(char c) {
- return isalnum(c) || (c == '_') || (c == '-') || (c == '/');
-}
-
-// Returns true iff character c can appear at the beginning of a number.
-inline bool IsValidCharAtStartOfNumber(char c) {
- return isdigit(c) || (c == '+') || (c == '-');
-}
-
-// Returns true iff character c can appear inside a number.
-inline bool IsValidCharInsideNumber(char c) {
- return isdigit(c) || (c == '.');
-}
-} // namespace
-
-bool FMLParser::Initialize(const std::string &source) {
- // Initialize parser state.
- source_ = source;
- current_ = source_.begin();
- item_start_ = line_start_ = current_;
- line_number_ = item_line_number_ = 1;
-
- // Read first input item.
- return NextItem();
-}
-
-void FMLParser::ReportError(const std::string &error_message) {
- const int position = item_start_ - line_start_ + 1;
- const std::string line(line_start_, current_);
-
- TC_LOG(ERROR) << "Error in feature model, line " << item_line_number_
- << ", position " << position << ": " << error_message
- << "\n " << line << " <--HERE";
-}
-
-void FMLParser::Next() {
- // Move to the next input character. If we are at a line break update line
- // number and line start position.
- if (CurrentChar() == '\n') {
- ++line_number_;
- ++current_;
- line_start_ = current_;
- } else {
- ++current_;
- }
-}
-
-bool FMLParser::NextItem() {
- // Skip white space and comments.
- while (!eos()) {
- if (CurrentChar() == '#') {
- // Skip comment.
- while (!eos() && CurrentChar() != '\n') Next();
- } else if (isspace(CurrentChar())) {
- // Skip whitespace.
- while (!eos() && isspace(CurrentChar())) Next();
- } else {
- break;
- }
- }
-
- // Record start position for next item.
- item_start_ = current_;
- item_line_number_ = line_number_;
-
- // Check for end of input.
- if (eos()) {
- item_type_ = END;
- return true;
- }
-
- // Parse number.
- if (IsValidCharAtStartOfNumber(CurrentChar())) {
- std::string::iterator start = current_;
- Next();
- while (!eos() && IsValidCharInsideNumber(CurrentChar())) Next();
- item_text_.assign(start, current_);
- item_type_ = NUMBER;
- return true;
- }
-
- // Parse std::string.
- if (CurrentChar() == '"') {
- Next();
- std::string::iterator start = current_;
- while (CurrentChar() != '"') {
- if (eos()) {
- ReportError("Unterminated string");
- return false;
- }
- Next();
- }
- item_text_.assign(start, current_);
- item_type_ = STRING;
- Next();
- return true;
- }
-
- // Parse identifier name.
- if (IsValidCharAtStartOfIdentifier(CurrentChar())) {
- std::string::iterator start = current_;
- while (!eos() && IsValidCharInsideIdentifier(CurrentChar())) {
- Next();
- }
- item_text_.assign(start, current_);
- item_type_ = NAME;
- return true;
- }
-
- // Single character item.
- item_type_ = CurrentChar();
- Next();
- return true;
-}
-
-bool FMLParser::Parse(const std::string &source,
- FeatureExtractorDescriptor *result) {
- // Initialize parser.
- if (!Initialize(source)) {
- return false;
- }
-
- while (item_type_ != END) {
- // Current item should be a feature name.
- if (item_type_ != NAME) {
- ReportError("Feature type name expected");
- return false;
- }
- std::string name = item_text_;
- if (!NextItem()) {
- return false;
- }
-
- // Parse feature.
- FeatureFunctionDescriptor *descriptor = result->add_feature();
- descriptor->set_type(name);
- if (!ParseFeature(descriptor)) {
- return false;
- }
- }
-
- return true;
-}
-
-bool FMLParser::ParseFeature(FeatureFunctionDescriptor *result) {
- // Parse argument and parameters.
- if (item_type_ == '(') {
- if (!NextItem()) return false;
- if (!ParseParameter(result)) return false;
- while (item_type_ == ',') {
- if (!NextItem()) return false;
- if (!ParseParameter(result)) return false;
- }
-
- if (item_type_ != ')') {
- ReportError(") expected");
- return false;
- }
- if (!NextItem()) return false;
- }
-
- // Parse feature name.
- if (item_type_ == ':') {
- if (!NextItem()) return false;
- if (item_type_ != NAME && item_type_ != STRING) {
- ReportError("Feature name expected");
- return false;
- }
- std::string name = item_text_;
- if (!NextItem()) return false;
-
- // Set feature name.
- result->set_name(name);
- }
-
- // Parse sub-features.
- if (item_type_ == '.') {
- // Parse dotted sub-feature.
- if (!NextItem()) return false;
- if (item_type_ != NAME) {
- ReportError("Feature type name expected");
- return false;
- }
- std::string type = item_text_;
- if (!NextItem()) return false;
-
- // Parse sub-feature.
- FeatureFunctionDescriptor *subfeature = result->add_feature();
- subfeature->set_type(type);
- if (!ParseFeature(subfeature)) return false;
- } else if (item_type_ == '{') {
- // Parse sub-feature block.
- if (!NextItem()) return false;
- while (item_type_ != '}') {
- if (item_type_ != NAME) {
- ReportError("Feature type name expected");
- return false;
- }
- std::string type = item_text_;
- if (!NextItem()) return false;
-
- // Parse sub-feature.
- FeatureFunctionDescriptor *subfeature = result->add_feature();
- subfeature->set_type(type);
- if (!ParseFeature(subfeature)) return false;
- }
- if (!NextItem()) return false;
- }
- return true;
-}
-
-bool FMLParser::ParseParameter(FeatureFunctionDescriptor *result) {
- if (item_type_ == NUMBER) {
- int32 argument;
- if (!ParseInt32(item_text_.c_str(), &argument)) {
- ReportError("Unable to parse number");
- return false;
- }
- if (!NextItem()) return false;
-
- // Set default argument for feature.
- result->set_argument(argument);
- } else if (item_type_ == NAME) {
- std::string name = item_text_;
- if (!NextItem()) return false;
- if (item_type_ != '=') {
- ReportError("= expected");
- return false;
- }
- if (!NextItem()) return false;
- if (item_type_ >= END) {
- ReportError("Parameter value expected");
- return false;
- }
- std::string value = item_text_;
- if (!NextItem()) return false;
-
- // Add parameter to feature.
- Parameter *parameter;
- parameter = result->add_parameter();
- parameter->set_name(name);
- parameter->set_value(value);
- } else {
- ReportError("Syntax error in parameter list");
- return false;
- }
- return true;
-}
-
-void ToFMLFunction(const FeatureFunctionDescriptor &function,
- std::string *output) {
- output->append(function.type());
- if (function.argument() != 0 || function.parameter_size() > 0) {
- output->append("(");
- bool first = true;
- if (function.argument() != 0) {
- output->append(IntToString(function.argument()));
- first = false;
- }
- for (int i = 0; i < function.parameter_size(); ++i) {
- if (!first) output->append(",");
- output->append(function.parameter(i).name());
- output->append("=");
- output->append("\"");
- output->append(function.parameter(i).value());
- output->append("\"");
- first = false;
- }
- output->append(")");
- }
-}
-
-void ToFML(const FeatureFunctionDescriptor &function, std::string *output) {
- ToFMLFunction(function, output);
- if (function.feature_size() == 1) {
- output->append(".");
- ToFML(function.feature(0), output);
- } else if (function.feature_size() > 1) {
- output->append(" { ");
- for (int i = 0; i < function.feature_size(); ++i) {
- if (i > 0) output->append(" ");
- ToFML(function.feature(i), output);
- }
- output->append(" } ");
- }
-}
-
-void ToFML(const FeatureExtractorDescriptor &extractor, std::string *output) {
- for (int i = 0; i < extractor.feature_size(); ++i) {
- ToFML(extractor.feature(i), output);
- output->append("\n");
- }
-}
-
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/common/fml-parser.h b/common/fml-parser.h
deleted file mode 100644
index b6b9da2..0000000
--- a/common/fml-parser.h
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Feature modeling language (fml) parser.
-//
-// BNF grammar for fml:
-//
-// <feature model> ::= { <feature extractor> }
-//
-// <feature extractor> ::= <extractor spec> |
-// <extractor spec> '.' <feature extractor> |
-// <extractor spec> '{' { <feature extractor> } '}'
-//
-// <extractor spec> ::= <extractor type>
-// [ '(' <parameter list> ')' ]
-// [ ':' <extractor name> ]
-//
-// <parameter list> = ( <parameter> | <argument> ) { ',' <parameter> }
-//
-// <parameter> ::= <parameter name> '=' <parameter value>
-//
-// <extractor type> ::= NAME
-// <extractor name> ::= NAME | STRING
-// <argument> ::= NUMBER
-// <parameter name> ::= NAME
-// <parameter value> ::= NUMBER | STRING | NAME
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_FML_PARSER_H_
-#define LIBTEXTCLASSIFIER_COMMON_FML_PARSER_H_
-
-#include <string>
-#include <vector>
-
-#include "common/feature-descriptors.h"
-#include "util/base/logging.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-class FMLParser {
- public:
- // Parses fml specification into feature extractor descriptor.
- // Returns true on success, false on error (e.g., syntax errors).
- bool Parse(const std::string &source, FeatureExtractorDescriptor *result);
-
- private:
- // Initializes the parser with the source text.
- // Returns true on success, false on syntax error.
- bool Initialize(const std::string &source);
-
- // Outputs an error message, with context info, and sets error_ to true.
- void ReportError(const std::string &error_message);
-
- // Moves to the next input character.
- void Next();
-
- // Moves to the next input item. Sets item_text_ and item_type_ accordingly.
- // Returns true on success, false on syntax error.
- bool NextItem();
-
- // Parses a feature descriptor.
- // Returns true on success, false on syntax error.
- bool ParseFeature(FeatureFunctionDescriptor *result);
-
- // Parses a parameter specification.
- // Returns true on success, false on syntax error.
- bool ParseParameter(FeatureFunctionDescriptor *result);
-
- // Returns true if end of source input has been reached.
- bool eos() const { return current_ >= source_.end(); }
-
- // Returns current character. Other methods should access the current
- // character through this method (instead of using *current_ directly): this
- // method performs extra safety checks.
- //
- // In case of an unsafe access, returns '\0'.
- char CurrentChar() const {
- if ((current_ >= source_.begin()) && (current_ < source_.end())) {
- return *current_;
- } else {
- TC_LOG(ERROR) << "Unsafe char read";
- return '\0';
- }
- }
-
- // Item types.
- enum ItemTypes {
- END = 0,
- NAME = -1,
- NUMBER = -2,
- STRING = -3,
- };
-
- // Source text.
- std::string source_;
-
- // Current input position.
- std::string::iterator current_;
-
- // Line number for current input position.
- int line_number_;
-
- // Start position for current item.
- std::string::iterator item_start_;
-
- // Start position for current line.
- std::string::iterator line_start_;
-
- // Line number for current item.
- int item_line_number_;
-
- // Item type for current item. If this is positive it is interpreted as a
- // character. If it is negative it is interpreted as an item type.
- int item_type_;
-
- // Text for current item.
- std::string item_text_;
-};
-
-// Converts a FeatureFunctionDescriptor into an FML spec (reverse of parsing).
-void ToFML(const FeatureFunctionDescriptor &function, std::string *output);
-
-// Like ToFML, but doesn't go into the nested functions. Instead, it generates
-// a string that starts with the name of the feature extraction function and
-// next, in-between parentheses, the parameters, separated by comma.
-// Intuitively, the constructed string is the prefix of ToFML, before the "{"
-// that starts the nested features.
-void ToFMLFunction(const FeatureFunctionDescriptor &function,
- std::string *output);
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_FML_PARSER_H_
diff --git a/common/fml-parser_test.cc b/common/fml-parser_test.cc
deleted file mode 100644
index b46048f..0000000
--- a/common/fml-parser_test.cc
+++ /dev/null
@@ -1,157 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common/fml-parser.h"
-
-#include "common/feature-descriptors.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-TEST(FMLParserTest, NoFeature) {
- FMLParser fml_parser;
- FeatureExtractorDescriptor descriptor;
- const std::string kFeatureName = "";
- EXPECT_TRUE(fml_parser.Parse(kFeatureName, &descriptor));
- EXPECT_EQ(0, descriptor.feature_size());
-}
-
-TEST(FMLParserTest, FeatureWithNoParams) {
- FMLParser fml_parser;
- FeatureExtractorDescriptor descriptor;
- const std::string kFeatureName = "continuous-bag-of-relevant-scripts";
- EXPECT_TRUE(fml_parser.Parse(kFeatureName, &descriptor));
- EXPECT_EQ(1, descriptor.feature_size());
- EXPECT_EQ(kFeatureName, descriptor.feature(0).type());
-}
-
-TEST(FMLParserTest, FeatureWithOneKeywordParameter) {
- FMLParser fml_parser;
- FeatureExtractorDescriptor descriptor;
- EXPECT_TRUE(fml_parser.Parse("myfeature(start=2)", &descriptor));
- EXPECT_EQ(1, descriptor.feature_size());
- EXPECT_EQ("myfeature", descriptor.feature(0).type());
- EXPECT_EQ(1, descriptor.feature(0).parameter_size());
- EXPECT_EQ("start", descriptor.feature(0).parameter(0).name());
- EXPECT_EQ("2", descriptor.feature(0).parameter(0).value());
- EXPECT_FALSE(descriptor.feature(0).has_argument());
-}
-
-TEST(FMLParserTest, FeatureWithDefaultArgumentNegative) {
- FMLParser fml_parser;
- FeatureExtractorDescriptor descriptor;
- EXPECT_TRUE(fml_parser.Parse("offset(-3)", &descriptor));
- EXPECT_EQ(1, descriptor.feature_size());
- EXPECT_EQ("offset", descriptor.feature(0).type());
- EXPECT_EQ(0, descriptor.feature(0).parameter_size());
- EXPECT_EQ(-3, descriptor.feature(0).argument());
-}
-
-TEST(FMLParserTest, FeatureWithDefaultArgumentPositive) {
- FMLParser fml_parser;
- FeatureExtractorDescriptor descriptor;
- EXPECT_TRUE(fml_parser.Parse("delta(7)", &descriptor));
- EXPECT_EQ(1, descriptor.feature_size());
- EXPECT_EQ("delta", descriptor.feature(0).type());
- EXPECT_EQ(0, descriptor.feature(0).parameter_size());
- EXPECT_EQ(7, descriptor.feature(0).argument());
-}
-
-TEST(FMLParserTest, FeatureWithDefaultArgumentZero) {
- FMLParser fml_parser;
- FeatureExtractorDescriptor descriptor;
- EXPECT_TRUE(fml_parser.Parse("delta(0)", &descriptor));
- EXPECT_EQ(1, descriptor.feature_size());
- EXPECT_EQ("delta", descriptor.feature(0).type());
- EXPECT_EQ(0, descriptor.feature(0).parameter_size());
- EXPECT_EQ(0, descriptor.feature(0).argument());
-}
-
-TEST(FMLParserTest, FeatureWithManyKeywordParameters) {
- FMLParser fml_parser;
- FeatureExtractorDescriptor descriptor;
- EXPECT_TRUE(fml_parser.Parse("myfeature(ratio=0.316,start=2,name=\"foo\")",
- &descriptor));
- EXPECT_EQ(1, descriptor.feature_size());
- EXPECT_EQ("myfeature", descriptor.feature(0).type());
- EXPECT_EQ(3, descriptor.feature(0).parameter_size());
- EXPECT_EQ("ratio", descriptor.feature(0).parameter(0).name());
- EXPECT_EQ("0.316", descriptor.feature(0).parameter(0).value());
- EXPECT_EQ("start", descriptor.feature(0).parameter(1).name());
- EXPECT_EQ("2", descriptor.feature(0).parameter(1).value());
- EXPECT_EQ("name", descriptor.feature(0).parameter(2).name());
- EXPECT_EQ("foo", descriptor.feature(0).parameter(2).value());
- EXPECT_FALSE(descriptor.feature(0).has_argument());
-}
-
-TEST(FMLParserTest, FeatureWithAllKindsOfParameters) {
- FMLParser fml_parser;
- FeatureExtractorDescriptor descriptor;
- EXPECT_TRUE(
- fml_parser.Parse("myfeature(17,ratio=0.316,start=2)", &descriptor));
- EXPECT_EQ(1, descriptor.feature_size());
- EXPECT_EQ("myfeature", descriptor.feature(0).type());
- EXPECT_EQ(2, descriptor.feature(0).parameter_size());
- EXPECT_EQ("ratio", descriptor.feature(0).parameter(0).name());
- EXPECT_EQ("0.316", descriptor.feature(0).parameter(0).value());
- EXPECT_EQ("start", descriptor.feature(0).parameter(1).name());
- EXPECT_EQ("2", descriptor.feature(0).parameter(1).value());
- EXPECT_EQ(17, descriptor.feature(0).argument());
-}
-
-TEST(FMLParserTest, FeatureWithWhitespaces) {
- FMLParser fml_parser;
- FeatureExtractorDescriptor descriptor;
- EXPECT_TRUE(fml_parser.Parse(
- " myfeature\t\t\t\n(17,\nratio=0.316 , start=2) ", &descriptor));
- EXPECT_EQ(1, descriptor.feature_size());
- EXPECT_EQ("myfeature", descriptor.feature(0).type());
- EXPECT_EQ(2, descriptor.feature(0).parameter_size());
- EXPECT_EQ("ratio", descriptor.feature(0).parameter(0).name());
- EXPECT_EQ("0.316", descriptor.feature(0).parameter(0).value());
- EXPECT_EQ("start", descriptor.feature(0).parameter(1).name());
- EXPECT_EQ("2", descriptor.feature(0).parameter(1).value());
- EXPECT_EQ(17, descriptor.feature(0).argument());
-}
-
-TEST(FMLParserTest, Broken_ParamWithoutValue) {
- FMLParser fml_parser;
- FeatureExtractorDescriptor descriptor;
- EXPECT_FALSE(
- fml_parser.Parse("myfeature(17,ratio=0.316,start)", &descriptor));
-}
-
-TEST(FMLParserTest, Broken_MissingCloseParen) {
- FMLParser fml_parser;
- FeatureExtractorDescriptor descriptor;
- EXPECT_FALSE(fml_parser.Parse("myfeature(17,ratio=0.316", &descriptor));
-}
-
-TEST(FMLParserTest, Broken_MissingOpenParen) {
- FMLParser fml_parser;
- FeatureExtractorDescriptor descriptor;
- EXPECT_FALSE(fml_parser.Parse("myfeature17,ratio=0.316)", &descriptor));
-}
-
-TEST(FMLParserTest, Broken_MissingQuote) {
- FMLParser fml_parser;
- FeatureExtractorDescriptor descriptor;
- EXPECT_FALSE(fml_parser.Parse("count(17,name=\"foo)", &descriptor));
-}
-
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/common/list-of-strings.proto b/common/list-of-strings.proto
deleted file mode 100644
index 5ba45ed..0000000
--- a/common/list-of-strings.proto
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright (C) 2017 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto2";
-option optimize_for = LITE_RUNTIME;
-
-package libtextclassifier.nlp_core;
-
-message ListOfStrings {
- repeated string element = 1;
-}
diff --git a/common/little-endian-data.h b/common/little-endian-data.h
deleted file mode 100644
index e3bc88f..0000000
--- a/common/little-endian-data.h
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_LITTLE_ENDIAN_DATA_H_
-#define LIBTEXTCLASSIFIER_COMMON_LITTLE_ENDIAN_DATA_H_
-
-#include <algorithm>
-#include <string>
-#include <vector>
-
-#include "util/base/endian.h"
-#include "util/base/logging.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// Swaps the sizeof(T) bytes that start at addr. E.g., if sizeof(T) == 2,
-// then (addr[0], addr[1]) -> (addr[1], addr[0]). Useful for little endian
-// <-> big endian conversions.
-template <class T>
-void SwapBytes(T *addr) {
- char *char_ptr = reinterpret_cast<char *>(addr);
- std::reverse(char_ptr, char_ptr + sizeof(T));
-}
-
-// Assuming addr points to a piece of data of type T, with its bytes in the
-// little/big endian order specific to the machine this code runs on, this
-// method will re-arrange the bytes (in place) in little-endian order.
-template <class T>
-void HostToLittleEndian(T *addr) {
- if (LittleEndian::IsLittleEndian()) {
- // Do nothing: current machine is little-endian.
- } else {
- SwapBytes(addr);
- }
-}
-
-// Reverse of HostToLittleEndian.
-template <class T>
-void LittleEndianToHost(T *addr) {
- // It turns out it's the same function: on little-endian machines, do nothing
- // (source and target formats are identical). Otherwise, swap bytes.
- HostToLittleEndian(addr);
-}
-
-// Returns string obtained by concatenating the bytes of the elements from a
-// vector (in order: v[0], v[1], etc). If the type T requires more than one
-// byte, the byte for each element are first converted to little-endian format.
-template<typename T>
-std::string GetDataBytesInLittleEndianOrder(const std::vector<T> &v) {
- std::string data_bytes;
- for (const T element : v) {
- T little_endian_element = element;
- HostToLittleEndian(&little_endian_element);
- data_bytes.append(
- reinterpret_cast<const char *>(&little_endian_element),
- sizeof(T));
- }
- return data_bytes;
-}
-
-// Performs reverse of GetDataBytesInLittleEndianOrder.
-//
-// I.e., decodes the data bytes from parameter bytes into num_elements Ts, and
-// places them in the vector v (previous content of that vector is erased).
-//
-// We expect bytes to contain the concatenation of the bytes for exactly
-// num_elements elements of type T. If the type T requires more than one byte,
-// those bytes should be arranged in little-endian form.
-//
-// Returns true on success and false otherwise (e.g., bytes has the wrong size).
-// Note: we do not want to crash on corrupted data (some clients, e..g, GMSCore,
-// have asked us not to do so). Instead, we report the error and let the client
-// decide what to do. On error, we also fill the vector with zeros, such that
-// at least the dimension of v matches expectations.
-template<typename T>
-bool FillVectorFromDataBytesInLittleEndian(
- const std::string &bytes, int num_elements, std::vector<T> *v) {
- if (bytes.size() != num_elements * sizeof(T)) {
- TC_LOG(ERROR) << "Wrong number of bytes: actual " << bytes.size()
- << " vs expected " << num_elements
- << " elements of sizeof(element) = " << sizeof(T)
- << " bytes each ; will fill vector with zeros";
- v->assign(num_elements, static_cast<T>(0));
- return false;
- }
- v->clear();
- v->reserve(num_elements);
- const T *start = reinterpret_cast<const T *>(bytes.data());
- if (LittleEndian::IsLittleEndian() || (sizeof(T) == 1)) {
- // Fast in the common case ([almost] all hardware today is little-endian):
- // if same endianness (or type T requires a single byte and endianness
- // irrelevant), just use the bytes.
- v->assign(start, start + num_elements);
- } else {
- // Slower (but very rare case): this code runs on a big endian machine and
- // the type T requires more than one byte. Hence, some conversion is
- // necessary.
- for (int i = 0; i < num_elements; ++i) {
- T temp = start[i];
- SwapBytes(&temp);
- v->push_back(temp);
- }
- }
- return true;
-}
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_LITTLE_ENDIAN_DATA_H_
diff --git a/common/memory_image/data-store.cc b/common/memory_image/data-store.cc
deleted file mode 100644
index a5f500c..0000000
--- a/common/memory_image/data-store.cc
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common/memory_image/data-store.h"
-
-#include "util/base/logging.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace memory_image {
-
-DataStore::DataStore(StringPiece bytes)
- : reader_(bytes.data(), bytes.size()) {
- if (!reader_.success_status()) {
- TC_LOG(ERROR) << "Unable to successfully initialize DataStore.";
- }
-}
-
-StringPiece DataStore::GetData(const std::string &name) const {
- if (!reader_.success_status()) {
- TC_LOG(ERROR) << "DataStore::GetData(" << name << ") "
- << "called on invalid DataStore; will return empty data "
- << "chunk";
- return StringPiece();
- }
-
- const auto &entries = reader_.trimmed_proto().entries();
- const auto &it = entries.find(name);
- if (it == entries.end()) {
- TC_LOG(ERROR) << "Unknown key: " << name
- << "; will return empty data chunk";
- return StringPiece();
- }
-
- const DataStoreEntryBytes &entry_bytes = it->second;
- if (!entry_bytes.has_blob_index()) {
- TC_LOG(ERROR) << "DataStoreEntryBytes with no blob_index; "
- << "will return empty data chunk.";
- return StringPiece();
- }
-
- int blob_index = entry_bytes.blob_index();
- return reader_.data_blob_view(blob_index);
-}
-
-} // namespace memory_image
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/common/memory_image/data-store.h b/common/memory_image/data-store.h
deleted file mode 100644
index 56aa4fc..0000000
--- a/common/memory_image/data-store.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_DATA_STORE_H_
-#define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_DATA_STORE_H_
-
-#include <string>
-
-#include "common/memory_image/data-store.pb.h"
-#include "common/memory_image/memory-image-reader.h"
-#include "util/strings/stringpiece.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace memory_image {
-
-// Class to access a data store. See usage example in comments for
-// DataStoreBuilder.
-class DataStore {
- public:
- // Constructs a DataStore using the indicated bytes, i.e., bytes.size() bytes
- // starting at address bytes.data(). These bytes should contain the
- // serialization of a data store, see DataStoreBuilder::SerializeAsString().
- explicit DataStore(StringPiece bytes);
-
- // Retrieves (start_addr, num_bytes) info for piece of memory that contains
- // the data associated with the indicated name. Note: this piece of memory is
- // inside the [start, start + size) (see constructor). This piece of memory
- // starts at an offset from start which is a multiple of the alignment
- // specified when the data store was built using DataStoreBuilder.
- //
- // If the alignment is a low power of 2 (e..g, 4, 8, or 16) and "start" passed
- // to constructor corresponds to the beginning of a memory page or an address
- // returned by new or malloc(), then start_addr is divisible with alignment.
- StringPiece GetData(const std::string &name) const;
-
- private:
- MemoryImageReader<DataStoreProto> reader_;
-};
-
-} // namespace memory_image
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_DATA_STORE_H_
diff --git a/common/memory_image/data-store.proto b/common/memory_image/data-store.proto
deleted file mode 100644
index 68e914a..0000000
--- a/common/memory_image/data-store.proto
+++ /dev/null
@@ -1,44 +0,0 @@
-// Copyright (C) 2017 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Protos for a data store: a barebone in-memory file system.
-//
-// A DataStore maintains an association between names and chunks of bytes. It
-// can be serialized into a string. Of course, it can be deserialized from a
-// string, with minimal parsing; after deserialization, all chunks of bytes
-// start at aligned addresses (aligned = multiple of an address specified at
-// build time).
-
-syntax = "proto2";
-option optimize_for = LITE_RUNTIME;
-
-package libtextclassifier.nlp_core.memory_image;
-
-// Bytes for a data store entry. They can be stored either directly in the
-// "data" field, or in the DataBlob with the 0-based index "blob_index".
-message DataStoreEntryBytes {
- oneof data {
- // Bytes for this data store entry, stored in this message.
- string in_place_data = 1;
-
- // 0-based index of the data blob with bytes for this data store entry. In
- // this case, the actual bytes are stored outside this message; the
- // DataStore code handles the association.
- int32 blob_index = 2 [default = -1];
- }
-}
-
-message DataStoreProto {
- map<string, DataStoreEntryBytes> entries = 1;
-}
diff --git a/common/memory_image/embedding-network-params-from-image.h b/common/memory_image/embedding-network-params-from-image.h
deleted file mode 100644
index e8c7d1e..0000000
--- a/common/memory_image/embedding-network-params-from-image.h
+++ /dev/null
@@ -1,225 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_
-#define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_
-
-#include "common/embedding-network-package.pb.h"
-#include "common/embedding-network-params.h"
-#include "common/embedding-network.pb.h"
-#include "common/memory_image/memory-image-reader.h"
-#include "util/base/integral_types.h"
-#include "util/strings/stringpiece.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// EmbeddingNetworkParams backed by a memory image.
-//
-// In this context, a memory image is like an EmbeddingNetworkProto, but with
-// all repeated weights (>99% of the size) directly usable (with no parsing
-// required).
-class EmbeddingNetworkParamsFromImage : public EmbeddingNetworkParams {
- public:
- // Constructs an EmbeddingNetworkParamsFromImage, using the memory image that
- // starts at address start and contains num_bytes bytes.
- EmbeddingNetworkParamsFromImage(const void *start, uint64 num_bytes)
- : memory_reader_(start, num_bytes),
- trimmed_proto_(memory_reader_.trimmed_proto()) {
- embeddings_blob_offset_ = 0;
-
- hidden_blob_offset_ = embeddings_blob_offset_ + embeddings_size();
- if (trimmed_proto_.embeddings_size() &&
- trimmed_proto_.embeddings(0).is_quantized()) {
- // Adjust for quantization: each quantized matrix takes two blobs (instead
- // of one): one for the quantized values and one for the scales.
- hidden_blob_offset_ += embeddings_size();
- }
-
- hidden_bias_blob_offset_ = hidden_blob_offset_ + hidden_size();
- softmax_blob_offset_ = hidden_bias_blob_offset_ + hidden_bias_size();
- softmax_bias_blob_offset_ = softmax_blob_offset_ + softmax_size();
- }
-
- ~EmbeddingNetworkParamsFromImage() override {}
-
- const TaskSpec *GetTaskSpec() override {
- auto extension_id = task_spec_in_embedding_network_proto;
- if (trimmed_proto_.HasExtension(extension_id)) {
- return &(trimmed_proto_.GetExtension(extension_id));
- } else {
- return nullptr;
- }
- }
-
- protected:
- int embeddings_size() const override {
- return trimmed_proto_.embeddings_size();
- }
-
- int embeddings_num_rows(int i) const override {
- TC_DCHECK(InRange(i, embeddings_size()));
- return trimmed_proto_.embeddings(i).rows();
- }
-
- int embeddings_num_cols(int i) const override {
- TC_DCHECK(InRange(i, embeddings_size()));
- return trimmed_proto_.embeddings(i).cols();
- }
-
- const void *embeddings_weights(int i) const override {
- TC_DCHECK(InRange(i, embeddings_size()));
- const int blob_index = trimmed_proto_.embeddings(i).is_quantized()
- ? (embeddings_blob_offset_ + 2 * i)
- : (embeddings_blob_offset_ + i);
- StringPiece data_blob_view = memory_reader_.data_blob_view(blob_index);
- return data_blob_view.data();
- }
-
- QuantizationType embeddings_quant_type(int i) const override {
- TC_DCHECK(InRange(i, embeddings_size()));
- if (trimmed_proto_.embeddings(i).is_quantized()) {
- return QuantizationType::UINT8;
- } else {
- return QuantizationType::NONE;
- }
- }
-
- const float16 *embeddings_quant_scales(int i) const override {
- TC_DCHECK(InRange(i, embeddings_size()));
- if (trimmed_proto_.embeddings(i).is_quantized()) {
- // Each embedding matrix has two atttached data blobs (hence the "2 * i"):
- // one blob with the quantized values and (immediately after it, hence the
- // "+ 1") one blob with the scales.
- int blob_index = embeddings_blob_offset_ + 2 * i + 1;
- StringPiece data_blob_view = memory_reader_.data_blob_view(blob_index);
- return reinterpret_cast<const float16 *>(data_blob_view.data());
- } else {
- return nullptr;
- }
- }
-
- int hidden_size() const override { return trimmed_proto_.hidden_size(); }
-
- int hidden_num_rows(int i) const override {
- TC_DCHECK(InRange(i, hidden_size()));
- return trimmed_proto_.hidden(i).rows();
- }
-
- int hidden_num_cols(int i) const override {
- TC_DCHECK(InRange(i, hidden_size()));
- return trimmed_proto_.hidden(i).cols();
- }
-
- const void *hidden_weights(int i) const override {
- TC_DCHECK(InRange(i, hidden_size()));
- StringPiece data_blob_view =
- memory_reader_.data_blob_view(hidden_blob_offset_ + i);
- return data_blob_view.data();
- }
-
- int hidden_bias_size() const override {
- return trimmed_proto_.hidden_bias_size();
- }
-
- int hidden_bias_num_rows(int i) const override {
- TC_DCHECK(InRange(i, hidden_bias_size()));
- return trimmed_proto_.hidden_bias(i).rows();
- }
-
- int hidden_bias_num_cols(int i) const override {
- TC_DCHECK(InRange(i, hidden_bias_size()));
- return trimmed_proto_.hidden_bias(i).cols();
- }
-
- const void *hidden_bias_weights(int i) const override {
- TC_DCHECK(InRange(i, hidden_bias_size()));
- StringPiece data_blob_view =
- memory_reader_.data_blob_view(hidden_bias_blob_offset_ + i);
- return data_blob_view.data();
- }
-
- int softmax_size() const override {
- return trimmed_proto_.has_softmax() ? 1 : 0;
- }
-
- int softmax_num_rows(int i) const override {
- TC_DCHECK(InRange(i, softmax_size()));
- return trimmed_proto_.softmax().rows();
- }
-
- int softmax_num_cols(int i) const override {
- TC_DCHECK(InRange(i, softmax_size()));
- return trimmed_proto_.softmax().cols();
- }
-
- const void *softmax_weights(int i) const override {
- TC_DCHECK(InRange(i, softmax_size()));
- StringPiece data_blob_view =
- memory_reader_.data_blob_view(softmax_blob_offset_ + i);
- return data_blob_view.data();
- }
-
- int softmax_bias_size() const override {
- return trimmed_proto_.has_softmax_bias() ? 1 : 0;
- }
-
- int softmax_bias_num_rows(int i) const override {
- TC_DCHECK(InRange(i, softmax_bias_size()));
- return trimmed_proto_.softmax_bias().rows();
- }
-
- int softmax_bias_num_cols(int i) const override {
- TC_DCHECK(InRange(i, softmax_bias_size()));
- return trimmed_proto_.softmax_bias().cols();
- }
-
- const void *softmax_bias_weights(int i) const override {
- TC_DCHECK(InRange(i, softmax_bias_size()));
- StringPiece data_blob_view =
- memory_reader_.data_blob_view(softmax_bias_blob_offset_ + i);
- return data_blob_view.data();
- }
-
- int embedding_num_features_size() const override {
- return trimmed_proto_.embedding_num_features_size();
- }
-
- int embedding_num_features(int i) const override {
- TC_DCHECK(InRange(i, embedding_num_features_size()));
- return trimmed_proto_.embedding_num_features(i);
- }
-
- private:
- MemoryImageReader<EmbeddingNetworkProto> memory_reader_;
-
- const EmbeddingNetworkProto &trimmed_proto_;
-
- // 0-based offsets in the list of data blobs for the different MatrixParams
- // fields. E.g., the 1st hidden MatrixParams has its weights stored in the
- // data blob number hidden_blob_offset_, the 2nd one in hidden_blob_offset_ +
- // 1, and so on.
- int embeddings_blob_offset_;
- int hidden_blob_offset_;
- int hidden_bias_blob_offset_;
- int softmax_blob_offset_;
- int softmax_bias_blob_offset_;
-};
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_
diff --git a/common/memory_image/in-memory-model-data.cc b/common/memory_image/in-memory-model-data.cc
deleted file mode 100644
index acf3d86..0000000
--- a/common/memory_image/in-memory-model-data.cc
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common/memory_image/in-memory-model-data.h"
-
-#include "common/file-utils.h"
-#include "util/base/logging.h"
-#include "util/strings/stringpiece.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-const char InMemoryModelData::kTaskSpecDataStoreEntryName[] = "TASK-SPEC-#@";
-const char InMemoryModelData::kFilePatternPrefix[] = "in-mem-model::";
-
-bool InMemoryModelData::GetTaskSpec(TaskSpec *task_spec) const {
- StringPiece blob = data_store_.GetData(kTaskSpecDataStoreEntryName);
- if (blob.data() == nullptr) {
- TC_LOG(ERROR) << "Can't find data blob for TaskSpec, i.e., entry "
- << kTaskSpecDataStoreEntryName;
- return false;
- }
- bool parse_status = file_utils::ParseProtoFromMemory(blob, task_spec);
- if (!parse_status) {
- TC_LOG(ERROR) << "Error parsing TaskSpec";
- return false;
- }
- return true;
-}
-
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/common/memory_image/in-memory-model-data.h b/common/memory_image/in-memory-model-data.h
deleted file mode 100644
index 91e4436..0000000
--- a/common/memory_image/in-memory-model-data.h
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_IN_MEMORY_MODEL_DATA_H_
-#define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_IN_MEMORY_MODEL_DATA_H_
-
-#include "common/memory_image/data-store.h"
-#include "common/task-spec.pb.h"
-#include "util/strings/stringpiece.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// In-memory representation of data for a Saft model. Provides access to a
-// TaskSpec object (produced by the "spec" stage of the Saft training model) and
-// to the bytes of the TaskInputs mentioned in that spec (all these bytes are in
-// memory, no file I/O required).
-//
-// Technically, an InMemoryModelData is a DataStore that maps the special string
-// kTaskSpecDataStoreEntryName to the binary serialization of a TaskSpec. For
-// each TaskInput (of the TaskSpec) with a file_pattern that starts with
-// kFilePatternPrefix (see below), the same DataStore maps file_pattern to some
-// content bytes. This way, it is possible to have all TaskInputs in memory,
-// while still allowing classic, on-disk TaskInputs.
-class InMemoryModelData {
- public:
- // Name for the DataStore entry that stores the serialized TaskSpec for the
- // entire model.
- static const char kTaskSpecDataStoreEntryName[];
-
- // Returns prefix for TaskInput::Part::file_pattern, to distinguish those
- // "files" from other files.
- static const char kFilePatternPrefix[];
-
- // Constructs an InMemoryModelData based on a chunk of bytes. Those bytes
- // should have been produced by a DataStoreBuilder.
- explicit InMemoryModelData(StringPiece bytes) : data_store_(bytes) {}
-
- // Fills *task_spec with a TaskSpec similar to the one used by
- // DataStoreBuilder (when building the bytes used to construct this
- // InMemoryModelData) except that each file name
- // (TaskInput::Part::file_pattern) is replaced with a name that can be used to
- // retrieve the corresponding file content bytes via GetBytesForInputFile().
- //
- // Returns true on success, false otherwise.
- bool GetTaskSpec(TaskSpec *task_spec) const;
-
- // Gets content bytes for a file. The file_name argument should be the
- // file_pattern for a TaskInput from the TaskSpec (see GetTaskSpec()).
- // Returns a StringPiece indicating a memory area with the content bytes. On
- // error, returns StringPiece(nullptr, 0).
- StringPiece GetBytesForInputFile(const std::string &file_name) const {
- return data_store_.GetData(file_name);
- }
-
- private:
- const memory_image::DataStore data_store_;
-};
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_IN_MEMORY_MODEL_DATA_H_
diff --git a/common/memory_image/low-level-memory-reader.h b/common/memory_image/low-level-memory-reader.h
deleted file mode 100644
index c87c772..0000000
--- a/common/memory_image/low-level-memory-reader.h
+++ /dev/null
@@ -1,142 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_LOW_LEVEL_MEMORY_READER_H_
-#define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_LOW_LEVEL_MEMORY_READER_H_
-
-#include <string.h>
-
-#include <string>
-
-#include "util/base/endian.h"
-#include "util/base/integral_types.h"
-#include "util/base/logging.h"
-#include "util/strings/stringpiece.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-class LowLevelMemReader {
- public:
- // Constructs a MemReader instance that reads at most num_available_bytes
- // starting from address start.
- LowLevelMemReader(const void *start, uint64 num_available_bytes)
- : current_(reinterpret_cast<const char *>(start)),
- // 0 bytes available if start == nullptr
- num_available_bytes_(start ? num_available_bytes : 0),
- num_loaded_bytes_(0) {
- }
-
- // Copies length bytes of data to address target. Advances current position
- // and returns true on success and false otherwise.
- bool Read(void *target, uint64 length) {
- if (length > num_available_bytes_) {
- TC_LOG(WARNING) << "Not enough bytes: available " << num_available_bytes_
- << " < required " << length;
- return false;
- }
- memcpy(target, current_, length);
- Advance(length);
- return true;
- }
-
- // Reads the string encoded at the current position. The bytes starting at
- // current position should contain (1) little-endian uint32 size (in bytes) of
- // the actual string and next (2) the actual bytes of the string. Advances
- // the current position and returns true if successful, false otherwise.
- //
- // On success, sets *view to be a view of the relevant bytes: view.data()
- // points to the beginning of the string bytes, and view.size() is the number
- // of such bytes.
- bool ReadString(StringPiece *view) {
- uint32 size;
- if (!Read(&size, sizeof(size))) {
- TC_LOG(ERROR) << "Unable to read std::string size";
- return false;
- }
- size = LittleEndian::ToHost32(size);
- if (size > num_available_bytes_) {
- TC_LOG(WARNING) << "Not enough bytes: " << num_available_bytes_
- << " available < " << size << " required ";
- return false;
- }
- *view = StringPiece(current_, size);
- Advance(size);
- return true;
- }
-
- // Like ReadString(StringPiece *) but reads directly into a C++ string,
- // instead of a StringPiece (StringPiece-like object).
- bool ReadString(std::string *target) {
- StringPiece view;
- if (!ReadString(&view)) {
- return false;
- }
- *target = view.ToString();
- return true;
- }
-
- // Returns current position.
- const char *GetCurrent() const { return current_; }
-
- // Returns remaining number of available bytes.
- uint64 GetNumAvailableBytes() const { return num_available_bytes_; }
-
- // Returns number of bytes read ("loaded") so far.
- uint64 GetNumLoadedBytes() const { return num_loaded_bytes_; }
-
- // Advance the current read position by indicated number of bytes. Returns
- // true on success, false otherwise (e.g., if there are not enough available
- // bytes to advance num_bytes).
- bool Advance(uint64 num_bytes) {
- if (num_bytes > num_available_bytes_) {
- return false;
- }
-
- // Next line never results in an underflow of the unsigned
- // num_available_bytes_, due to the previous if.
- num_available_bytes_ -= num_bytes;
- current_ += num_bytes;
- num_loaded_bytes_ += num_bytes;
- return true;
- }
-
- // Advance current position to nearest multiple of alignment. Returns false
- // if not enough bytes available to do that, true (success) otherwise.
- bool SkipToAlign(int alignment) {
- int num_extra_bytes = num_loaded_bytes_ % alignment;
- if (num_extra_bytes == 0) {
- return true;
- }
- return Advance(alignment - num_extra_bytes);
- }
-
- private:
- // Current position in the in-memory data. Next call to Read() will read from
- // this address.
- const char *current_;
-
- // Number of available bytes we can still read.
- uint64 num_available_bytes_;
-
- // Number of bytes read ("loaded") so far.
- uint64 num_loaded_bytes_;
-};
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_LOW_LEVEL_MEMORY_READER_H_
diff --git a/common/memory_image/memory-image-common.cc b/common/memory_image/memory-image-common.cc
deleted file mode 100644
index 6debf1d..0000000
--- a/common/memory_image/memory-image-common.cc
+++ /dev/null
@@ -1,31 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common/memory_image/memory-image-common.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// IMPORTANT: this signature should never change. If you change the protocol,
-// update kCurrentVersion, *not* this signature.
-const char MemoryImageConstants::kSignature[] = "Memory image $5%1#o3-1x32";
-
-const int MemoryImageConstants::kCurrentVersion = 1;
-
-const int MemoryImageConstants::kDefaultAlignment = 16;
-
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/common/memory_image/memory-image-common.h b/common/memory_image/memory-image-common.h
deleted file mode 100644
index 3a46f49..0000000
--- a/common/memory_image/memory-image-common.h
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Common utils for memory images.
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_MEMORY_IMAGE_COMMON_H_
-#define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_MEMORY_IMAGE_COMMON_H_
-
-#include <stddef.h>
-
-#include <string>
-
-#include "util/strings/stringpiece.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-class MemoryImageConstants {
- public:
- static const char kSignature[];
- static const int kCurrentVersion;
- static const int kDefaultAlignment;
-};
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_MEMORY_IMAGE_COMMON_H_
diff --git a/common/memory_image/memory-image-reader.cc b/common/memory_image/memory-image-reader.cc
deleted file mode 100644
index 7e717d5..0000000
--- a/common/memory_image/memory-image-reader.cc
+++ /dev/null
@@ -1,125 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common/memory_image/memory-image-reader.h"
-
-#include <string>
-
-#include "common/memory_image/low-level-memory-reader.h"
-#include "common/memory_image/memory-image-common.h"
-#include "common/memory_image/memory-image.pb.h"
-#include "util/base/endian.h"
-#include "util/base/logging.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-namespace {
-
-// Checks that the memory area read by mem_reader starts with the expected
-// signature. Advances mem_reader past the signature and returns success
-// status.
-bool ReadAndCheckSignature(LowLevelMemReader *mem_reader) {
- const std::string expected_signature = MemoryImageConstants::kSignature;
- const int signature_size = expected_signature.size();
- if (mem_reader->GetNumAvailableBytes() < signature_size) {
- TC_LOG(ERROR) << "Not enough bytes to check signature";
- return false;
- }
- const std::string actual_signature(mem_reader->GetCurrent(), signature_size);
- if (!mem_reader->Advance(signature_size)) {
- TC_LOG(ERROR) << "Failed to advance past signature";
- return false;
- }
- if (actual_signature != expected_signature) {
- TC_LOG(ERROR) << "Different signature: actual \"" << actual_signature
- << "\" != expected \"" << expected_signature << "\"";
- return false;
- }
- return true;
-}
-
-// Parses MemoryImageHeader from mem_reader. Advances mem_reader past it.
-// Returns success status.
-bool ParseMemoryImageHeader(
- LowLevelMemReader *mem_reader, MemoryImageHeader *header) {
- std::string header_proto_str;
- if (!mem_reader->ReadString(&header_proto_str)) {
- TC_LOG(ERROR) << "Unable to read header_proto_str";
- return false;
- }
- if (!header->ParseFromString(header_proto_str)) {
- TC_LOG(ERROR) << "Unable to parse MemoryImageHeader";
- return false;
- }
- return true;
-}
-
-} // namespace
-
-bool GeneralMemoryImageReader::ReadMemoryImage() {
- LowLevelMemReader mem_reader(start_, num_bytes_);
-
- // Read and check signature.
- if (!ReadAndCheckSignature(&mem_reader)) {
- return false;
- }
-
- // Parse MemoryImageHeader header_.
- if (!ParseMemoryImageHeader(&mem_reader, &header_)) {
- return false;
- }
-
- // Check endianness.
- if (header_.is_little_endian() != LittleEndian::IsLittleEndian()) {
- // TODO(salcianu): implement conversion: it will take time, but it's better
- // than crashing. Not very urgent: [almost] all current Android phones are
- // little-endian.
- TC_LOG(ERROR) << "Memory image is "
- << (header_.is_little_endian() ? "little" : "big")
- << " endian. "
- << "Local system is different and we don't currently support "
- << "conversion between the two.";
- return false;
- }
-
- // Read binary serialization of trimmed original proto.
- if (!mem_reader.ReadString(&trimmed_proto_serialization_)) {
- TC_LOG(ERROR) << "Unable to read trimmed proto binary serialization";
- return false;
- }
-
- // Fill vector of pointers to beginning of each data blob.
- for (int i = 0; i < header_.blob_info_size(); ++i) {
- const MemoryImageDataBlobInfo &blob_info = header_.blob_info(i);
- if (!mem_reader.SkipToAlign(header_.alignment())) {
- TC_LOG(ERROR) << "Unable to align for blob #i" << i;
- return false;
- }
- data_blob_views_.emplace_back(
- mem_reader.GetCurrent(),
- blob_info.num_bytes());
- if (!mem_reader.Advance(blob_info.num_bytes())) {
- TC_LOG(ERROR) << "Not enough bytes for blob #i" << i;
- return false;
- }
- }
-
- return true;
-}
-
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/common/memory_image/memory-image-reader.h b/common/memory_image/memory-image-reader.h
deleted file mode 100644
index c5954fd..0000000
--- a/common/memory_image/memory-image-reader.h
+++ /dev/null
@@ -1,154 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// MemoryImageReader, class for reading a memory image.
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_MEMORY_IMAGE_READER_H_
-#define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_MEMORY_IMAGE_READER_H_
-
-#include <string>
-#include <vector>
-
-#include "common/memory_image/memory-image.pb.h"
-#include "util/base/integral_types.h"
-#include "util/base/logging.h"
-#include "util/base/macros.h"
-#include "util/strings/stringpiece.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// General, non-templatized class, to reduce code duplication.
-//
-// Given a memory area (pointer to start + size in bytes) parses a memory image
-// from there into (1) MemoryImageHeader proto (it includes the serialized form
-// of the trimmed down original proto) and (2) a list of void* pointers to the
-// beginning of all data blobs.
-//
-// In case of parsing errors, we prefer to log the error and set the
-// success_status() to false, instead of CHECK-failing . This way, the client
-// has the option of performing error recovery or crashing. Some mobile apps
-// don't like crashing (a restart is very slow) so, if possible, we try to avoid
-// that.
-class GeneralMemoryImageReader {
- public:
- // Constructs this object. See class-level comments. Note: the memory area
- // pointed to by start should not be deallocated while this object is used:
- // this object does not copy it; instead, it keeps pointers inside that memory
- // area.
- GeneralMemoryImageReader(const void *start, uint64 num_bytes)
- : start_(start), num_bytes_(num_bytes) {
- success_ = ReadMemoryImage();
- }
-
- virtual ~GeneralMemoryImageReader() {}
-
- // Returns true if reading the memory image has been successful. If this
- // returns false, then none of the other accessors should be used.
- bool success_status() const { return success_; }
-
- // Returns number of data blobs from the memory image.
- int num_data_blobs() const {
- return data_blob_views_.size();
- }
-
- // Returns pointer to the beginning of the data blob #i.
- StringPiece data_blob_view(int i) const {
- if ((i < 0) || (i >= num_data_blobs())) {
- TC_LOG(ERROR) << "Blob index " << i << " outside range [0, "
- << num_data_blobs() << "); will return empty data chunk";
- return StringPiece();
- }
- return data_blob_views_[i];
- }
-
- // Returns std::string with binary serialization of the original proto, but
- // trimmed of the large fields (those were placed in the data blobs).
- std::string trimmed_proto_str() const {
- return trimmed_proto_serialization_.ToString();
- }
-
- // Same as above but returns the trimmed proto as a string piece pointing to
- // the image.
- StringPiece trimmed_proto_view() const {
- return trimmed_proto_serialization_;
- }
-
- const MemoryImageHeader &header() { return header_; }
-
- protected:
- void set_as_failed() {
- success_ = false;
- }
-
- private:
- bool ReadMemoryImage();
-
- // Pointer to beginning of memory image. Not owned.
- const void *const start_;
-
- // Number of bytes in the memory image. This class will not read more bytes.
- const uint64 num_bytes_;
-
- // MemoryImageHeader parsed from the memory image.
- MemoryImageHeader header_;
-
- // Binary serialization of the trimmed version of the original proto.
- // Represented as a StringPiece backed up by the underlying memory image
- // bytes.
- StringPiece trimmed_proto_serialization_;
-
- // List of StringPiece objects for all data blobs from the memory image (in
- // order).
- std::vector<StringPiece> data_blob_views_;
-
- // Memory reading success status.
- bool success_;
-
- TC_DISALLOW_COPY_AND_ASSIGN(GeneralMemoryImageReader);
-};
-
-// Like GeneralMemoryImageReader, but has knowledge about the type of the
-// original proto. As such, it can parse it (well, the trimmed version) and
-// offer access to it.
-//
-// Template parameter T should be the type of the original proto.
-template<class T>
-class MemoryImageReader : public GeneralMemoryImageReader {
- public:
- MemoryImageReader(const void *start, uint64 num_bytes)
- : GeneralMemoryImageReader(start, num_bytes) {
- if (!trimmed_proto_.ParseFromString(trimmed_proto_str())) {
- TC_LOG(INFO) << "Unable to parse the trimmed proto";
- set_as_failed();
- }
- }
-
- // Returns const reference to the trimmed version of the original proto.
- // Useful for retrieving the many small fields that are not converted into
- // data blobs.
- const T &trimmed_proto() const { return trimmed_proto_; }
-
- private:
- T trimmed_proto_;
-
- TC_DISALLOW_COPY_AND_ASSIGN(MemoryImageReader);
-};
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_MEMORY_IMAGE_READER_H_
diff --git a/common/memory_image/memory-image.proto b/common/memory_image/memory-image.proto
deleted file mode 100644
index f6b624c..0000000
--- a/common/memory_image/memory-image.proto
+++ /dev/null
@@ -1,63 +0,0 @@
-// Copyright (C) 2017 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Protos for "memory images".
-
-syntax = "proto2";
-option optimize_for = LITE_RUNTIME;
-
-package libtextclassifier.nlp_core;
-
-message MemoryImageDataBlobInfo {
- // Size (in bytes) of this data blob.
- optional uint64 num_bytes = 1;
-
- // Indicates whether this data blob corresponds to an array.
- optional bool is_array = 2 [default = true];
-
- // Size (in bytes) of each array element. Useful for little <-> big endian
- // conversions. -1 means unknown: no endianness conversion in that case.
- optional int32 element_size = 3 [default = -1];
-}
-
-message MemoryImageHeader {
- // Version of the algorithm used to produce the memory image. We should
- // increase the value used here every time we perform an incompatible change.
- // Algorithm version v should handle only memory images of the same version,
- // and crash otherwise.
- optional int32 version = 1 [default = -1];
-
- // True if the info stored in the data blobs uses the little endian
- // convention. Almost all machines today are little-endian but we want to be
- // able to crash with an informative message or perform a (costly) conversion
- // in the rare cases when that's not true.
- optional bool is_little_endian = 2 [default = true];
-
- // Alignment (in bytes) for all data blobs. E.g., if this field is 16, then
- // each data blob starts at an offset that's a multiple of 16, where the
- // offset is measured from the beginning of the memory image. On the client
- // side, allocating the entire memory image at an aligned address (by same
- // alignment) makes sure all data blobs are properly aligned.
- //
- // NOTE: I (salcianu) explored the idea of a different alignment for each data
- // blob: e.g., float[] should be fine with 4-byte alignment (alignment = 4)
- // but char[] are fine with no alignment (alignment = 1). As we expect only a
- // few (but large) data blobs, the space benefit is not worth the extra code
- // complexity.
- optional int32 alignment = 3 [default = 8];
-
- // One MemoryImageDataBlobInfo for each data blob, in order. There is one
- // data blob for each large field we handle specially.
- repeated MemoryImageDataBlobInfo blob_info = 4;
-}
diff --git a/common/mock_functions.cc b/common/mock_functions.cc
deleted file mode 100644
index c661b70..0000000
--- a/common/mock_functions.cc
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common/mock_functions.h"
-
-#include "common/registry.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-TC_DEFINE_CLASS_REGISTRY_NAME("function", functions::Function);
-
-TC_DEFINE_CLASS_REGISTRY_NAME("int-function", functions::IntFunction);
-
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/common/mock_functions.h b/common/mock_functions.h
deleted file mode 100644
index b5bcb07..0000000
--- a/common/mock_functions.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_MOCK_FUNCTIONS_H_
-#define LIBTEXTCLASSIFIER_COMMON_MOCK_FUNCTIONS_H_
-
-#include <math.h>
-
-#include "common/registry.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace functions {
-
-// Abstract double -> double function.
-class Function : public RegisterableClass<Function> {
- public:
- virtual ~Function() {}
- virtual double Evaluate(double x) = 0;
-};
-
-class Cos : public Function {
- public:
- double Evaluate(double x) override { return cos(x); }
- TC_DEFINE_REGISTRATION_METHOD("cos", Cos);
-};
-
-class Exp : public Function {
- public:
- double Evaluate(double x) override { return exp(x); }
- TC_DEFINE_REGISTRATION_METHOD("exp", Exp);
-};
-
-// Abstract int -> int function.
-class IntFunction : public RegisterableClass<IntFunction> {
- public:
- virtual ~IntFunction() {}
- virtual int Evaluate(int k) = 0;
-};
-
-class Inc : public IntFunction {
- public:
- int Evaluate(int k) override { return k + 1; }
- TC_DEFINE_REGISTRATION_METHOD("inc", Inc);
-};
-
-class Dec : public IntFunction {
- public:
- int Evaluate(int k) override { return k + 1; }
- TC_DEFINE_REGISTRATION_METHOD("dec", Dec);
-};
-
-} // namespace functions
-
-// Should be inside namespace libtextclassifier::nlp_core.
-TC_DECLARE_CLASS_REGISTRY_NAME(functions::Function);
-TC_DECLARE_CLASS_REGISTRY_NAME(functions::IntFunction);
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_MOCK_FUNCTIONS_H_
diff --git a/common/registry.h b/common/registry.h
deleted file mode 100644
index d958225..0000000
--- a/common/registry.h
+++ /dev/null
@@ -1,281 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Mechanism to instantiate classes by name.
-//
-// This mechanism is useful if the concrete classes to be instantiated are not
-// statically known (e.g., if their names are read from a dynamically-provided
-// config).
-//
-// In that case, the first step is to define the API implemented by the
-// instantiated classes. E.g.,
-//
-// // In a header file function.h:
-//
-// // Abstract function that takes a double and returns a double.
-// class Function : public RegisterableClass<Function> {
-// public:
-// virtual ~Function() {}
-// virtual double Evaluate(double x) = 0;
-// };
-//
-// // Should be inside namespace libtextclassifier::nlp_core.
-// TC_DECLARE_CLASS_REGISTRY_NAME(Function);
-//
-// Notice the inheritance from RegisterableClass<Function>. RegisterableClass
-// is defined by this file (registry.h). Under the hood, this inheritanace
-// defines a "registry" that maps names (zero-terminated arrays of chars) to
-// factory methods that create Functions. You should give a human-readable name
-// to this registry. To do that, use the following macro in a .cc file (it has
-// to be a .cc file, as it defines some static data):
-//
-// // Inside function.cc
-// // Should be inside namespace libtextclassifier::nlp_core.
-// TC_DEFINE_CLASS_REGISTRY_NAME("function", Function);
-//
-// Now, let's define a few concrete Functions: e.g.,
-//
-// class Cos : public Function {
-// public:
-// double Evaluate(double x) override { return cos(x); }
-// TC_DEFINE_REGISTRATION_METHOD("cos", Cos);
-// };
-//
-// class Exp : public Function {
-// public:
-// double Evaluate(double x) override { return exp(x); }
-// TC_DEFINE_REGISTRATION_METHOD("sin", Sin);
-// };
-//
-// Each concrete Function implementation should have (in the public section) the
-// macro
-//
-// TC_DEFINE_REGISTRATION_METHOD("name", implementation_class);
-//
-// This defines a RegisterClass static method that, when invoked, associates
-// "name" with a factory method that creates instances of implementation_class.
-//
-// Before instantiating Functions by name, we need to tell our system which
-// Functions we may be interested in. This is done by calling the
-// Foo::RegisterClass() for each relevant Foo implementation of Function. It is
-// ok to call Foo::RegisterClass() multiple times (even in parallel): only the
-// first call will perform something, the others will return immediately.
-//
-// Cos::RegisterClass();
-// Exp::RegisterClass();
-//
-// Now, let's instantiate a Function based on its name. This get a lot more
-// interesting if the Function name is not statically known (i.e.,
-// read from an input proto:
-//
-// std::unique_ptr<Function> f(Function::Create("cos"));
-// double result = f->Evaluate(arg);
-//
-// NOTE: the same binary can use this mechanism for different APIs. E.g., one
-// can also have (in the binary with Function, Sin, Cos, etc):
-//
-// class IntFunction : public RegisterableClass<IntFunction> {
-// public:
-// virtual ~IntFunction() {}
-// virtual int Evaluate(int k) = 0;
-// };
-//
-// TC_DECLARE_CLASS_REGISTRY_NAME(IntFunction);
-//
-// TC_DEFINE_CLASS_REGISTRY_NAME("int function", IntFunction);
-//
-// class Inc : public IntFunction {
-// public:
-// int Evaluate(int k) override { return k + 1; }
-// TC_DEFINE_REGISTRATION_METHOD("inc", Inc);
-// };
-//
-// RegisterableClass<Function> and RegisterableClass<IntFunction> define their
-// own registries: each maps string names to implementation of the corresponding
-// API.
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_REGISTRY_H_
-#define LIBTEXTCLASSIFIER_COMMON_REGISTRY_H_
-
-#include <stdlib.h>
-#include <string.h>
-
-#include <string>
-
-#include "util/base/logging.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-namespace internal {
-// Registry that associates keys (zero-terminated array of chars) with values.
-// Values are pointers to type T (the template parameter). This is used to
-// store the association between component names and factory methods that
-// produce those components; the error messages are focused on that case.
-//
-// Internally, this registry uses a linked list of (key, value) pairs. We do
-// not use an STL map, list, etc because we aim for small code size.
-template <class T>
-class ComponentRegistry {
- public:
- explicit ComponentRegistry(const char *name) : name_(name), head_(nullptr) {}
-
- // Adds a the (key, value) pair to this registry (if the key does not already
- // exists in this registry) and returns true. If the registry already has a
- // mapping for key, returns false and does not modify the registry. NOTE: the
- // error (false) case happens even if the existing value for key is equal with
- // the new one.
- //
- // This method does not take ownership of key, nor of value.
- bool Add(const char *key, T *value) {
- const Cell *old_cell = FindCell(key);
- if (old_cell != nullptr) {
- TC_LOG(ERROR) << "Duplicate component: " << key;
- return false;
- }
- Cell *new_cell = new Cell(key, value, head_);
- head_ = new_cell;
- return true;
- }
-
- // Returns the value attached to a key in this registry. Returns nullptr on
- // error (e.g., unknown key).
- T *Lookup(const char *key) const {
- const Cell *cell = FindCell(key);
- if (cell == nullptr) {
- TC_LOG(ERROR) << "Unknown " << name() << " component: " << key;
- }
- return (cell == nullptr) ? nullptr : cell->value();
- }
-
- T *Lookup(const std::string &key) const { return Lookup(key.c_str()); }
-
- // Returns name of this ComponentRegistry.
- const char *name() const { return name_; }
-
- private:
- // Cell for the singly-linked list underlying this ComponentRegistry. Each
- // cell contains a key, the value for that key, as well as a pointer to the
- // next Cell from the list.
- class Cell {
- public:
- // Constructs a new Cell.
- Cell(const char *key, T *value, Cell *next)
- : key_(key), value_(value), next_(next) {}
-
- const char *key() const { return key_; }
- T *value() const { return value_; }
- Cell *next() const { return next_; }
-
- private:
- const char *const key_;
- T *const value_;
- Cell *const next_;
- };
-
- // Finds Cell for indicated key in the singly-linked list pointed to by head_.
- // Returns pointer to that first Cell with that key, or nullptr if no such
- // Cell (i.e., unknown key).
- //
- // Caller does NOT own the returned pointer.
- const Cell *FindCell(const char *key) const {
- Cell *c = head_;
- while (c != nullptr && strcmp(key, c->key()) != 0) {
- c = c->next();
- }
- return c;
- }
-
- // Human-readable description for this ComponentRegistry. For debug purposes.
- const char *const name_;
-
- // Pointer to the first Cell from the underlying list of (key, value) pairs.
- Cell *head_;
-};
-} // namespace internal
-
-// Base class for registerable classes.
-template <class T>
-class RegisterableClass {
- public:
- // Factory function type.
- typedef T *(Factory)();
-
- // Registry type.
- typedef internal::ComponentRegistry<Factory> Registry;
-
- // Creates a new instance of T. Returns pointer to new instance or nullptr in
- // case of errors (e.g., unknown component).
- //
- // Passes ownership of the returned pointer to the caller.
- static T *Create(const std::string &name) { // NOLINT
- auto *factory = registry()->Lookup(name);
- if (factory == nullptr) {
- TC_LOG(ERROR) << "Unknown RegisterableClass " << name;
- return nullptr;
- }
- return factory();
- }
-
- // Returns registry for class.
- static Registry *registry() {
- static Registry *registry_for_type_t = new Registry(kRegistryName);
- return registry_for_type_t;
- }
-
- protected:
- // Factory method for subclass ComponentClass. Used internally by the static
- // method RegisterClass() defined by TC_DEFINE_REGISTRATION_METHOD.
- template <class ComponentClass>
- static T *_internal_component_factory() {
- return new ComponentClass();
- }
-
- private:
- // Human-readable name for the registry for this class.
- static const char kRegistryName[];
-};
-
-// Defines the static method component_class::RegisterClass() that should be
-// called before trying to instantiate component_class by name. Should be used
-// inside the public section of the declaration of component_class. See
-// comments at the top-level of this file.
-#define TC_DEFINE_REGISTRATION_METHOD(component_name, component_class) \
- static void RegisterClass() { \
- static bool once = registry()->Add( \
- component_name, &_internal_component_factory<component_class>); \
- if (!once) { \
- TC_LOG(ERROR) << "Problem registering " << component_name; \
- } \
- TC_DCHECK(once); \
- }
-
-// Defines the human-readable name of the registry associated with base_class.
-#define TC_DECLARE_CLASS_REGISTRY_NAME(base_class) \
- template <> \
- const char ::libtextclassifier::nlp_core::RegisterableClass< \
- base_class>::kRegistryName[]
-
-// Defines the human-readable name of the registry associated with base_class.
-#define TC_DEFINE_CLASS_REGISTRY_NAME(registry_name, base_class) \
- template <> \
- const char ::libtextclassifier::nlp_core::RegisterableClass< \
- base_class>::kRegistryName[] = registry_name
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_REGISTRY_H_
diff --git a/common/registry_test.cc b/common/registry_test.cc
deleted file mode 100644
index d5d7006..0000000
--- a/common/registry_test.cc
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <memory>
-
-#include "common/mock_functions.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace functions {
-
-TEST(RegistryTest, InstantiateFunctionsByName) {
- // First, we need to register the functions we are interested in:
- Exp::RegisterClass();
- Inc::RegisterClass();
- Cos::RegisterClass();
-
- // RegisterClass methods can be called in any order, even multiple times :)
- Cos::RegisterClass();
- Inc::RegisterClass();
- Inc::RegisterClass();
- Cos::RegisterClass();
- Inc::RegisterClass();
-
- // NOTE: we intentionally do not register Dec. Attempts to create an instance
- // of that function by name should fail.
-
- // Instantiate a few functions and check that the created functions produce
- // the expected results for a few sample values.
- std::unique_ptr<Function> f1(Function::Create("cos"));
- ASSERT_NE(f1, nullptr);
- std::unique_ptr<Function> f2(Function::Create("exp"));
- ASSERT_NE(f2, nullptr);
- EXPECT_NEAR(f1->Evaluate(-3), -0.9899, 0.0001);
- EXPECT_NEAR(f2->Evaluate(2.3), 9.9741, 0.0001);
-
- std::unique_ptr<IntFunction> f3(IntFunction::Create("inc"));
- ASSERT_NE(f3, nullptr);
- EXPECT_EQ(f3->Evaluate(7), 8);
-
- // Instantiating unknown functions should return nullptr, but not crash
- // anything.
- EXPECT_EQ(Function::Create("mambo"), nullptr);
-
- // Functions that are defined in the code, but are not registered are unknown.
- EXPECT_EQ(IntFunction::Create("dec"), nullptr);
-
- // Function and IntFunction use different registries.
- EXPECT_EQ(IntFunction::Create("exp"), nullptr);
-}
-
-} // namespace functions
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/common/simple-adder.h b/common/simple-adder.h
deleted file mode 100644
index c16cc8a..0000000
--- a/common/simple-adder.h
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_SIMPLE_ADDER_H_
-#define LIBTEXTCLASSIFIER_COMMON_SIMPLE_ADDER_H_
-
-#include "util/base/integral_types.h"
-#include "util/base/port.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// Implements add and scaleadd in the most straight-forward way, and it doesn't
-// have any additional requirement on the alignment and array size.
-class SimpleAdder {
- public:
- TC_ATTRIBUTE_ALWAYS_INLINE SimpleAdder(float *dest, int num_floats)
- : dest_(dest), num_floats_(num_floats) {}
-
- TC_ATTRIBUTE_ALWAYS_INLINE void LazyAdd(const float *source) const {
- AddImpl(source, num_floats_, dest_);
- }
-
- TC_ATTRIBUTE_ALWAYS_INLINE void LazyScaleAdd(const float *source,
- const float scale) const {
- ScaleAddImpl(source, num_floats_, scale, dest_);
- }
-
- // Simple fast while loop to implement dest += source.
- TC_ATTRIBUTE_ALWAYS_INLINE static void AddImpl(const float *__restrict source,
- uint32 size,
- float *__restrict dest) {
- for (uint32 i = 0; i < size; ++i) {
- dest[i] += source[i];
- }
- }
-
- // Simple fast while loop to implement dest += scale * source.
- TC_ATTRIBUTE_ALWAYS_INLINE static void ScaleAddImpl(
- const float *__restrict source, uint32 size, const float scale,
- float *__restrict dest) {
- for (uint32 i = 0; i < size; ++i) {
- dest[i] += source[i] * scale;
- }
- }
-
- private:
- float *dest_;
- int num_floats_;
-};
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_SIMPLE_ADDER_H_
diff --git a/common/task-context.cc b/common/task-context.cc
deleted file mode 100644
index e4c1090..0000000
--- a/common/task-context.cc
+++ /dev/null
@@ -1,206 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common/task-context.h"
-
-#include <stdlib.h>
-
-#include <string>
-
-#include "util/base/integral_types.h"
-#include "util/base/logging.h"
-#include "util/strings/numbers.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-namespace {
-int32 ParseInt32WithDefault(const std::string &s, int32 defval) {
- int32 value = defval;
- return ParseInt32(s.c_str(), &value) ? value : defval;
-}
-
-int64 ParseInt64WithDefault(const std::string &s, int64 defval) {
- int64 value = defval;
- return ParseInt64(s.c_str(), &value) ? value : defval;
-}
-
-double ParseDoubleWithDefault(const std::string &s, double defval) {
- double value = defval;
- return ParseDouble(s.c_str(), &value) ? value : defval;
-}
-} // namespace
-
-TaskInput *TaskContext::GetInput(const std::string &name) {
- // Return existing input if it exists.
- for (int i = 0; i < spec_.input_size(); ++i) {
- if (spec_.input(i).name() == name) return spec_.mutable_input(i);
- }
-
- // Create new input.
- TaskInput *input = spec_.add_input();
- input->set_name(name);
- return input;
-}
-
-TaskInput *TaskContext::GetInput(const std::string &name,
- const std::string &file_format,
- const std::string &record_format) {
- TaskInput *input = GetInput(name);
- if (!file_format.empty()) {
- bool found = false;
- for (int i = 0; i < input->file_format_size(); ++i) {
- if (input->file_format(i) == file_format) found = true;
- }
- if (!found) input->add_file_format(file_format);
- }
- if (!record_format.empty()) {
- bool found = false;
- for (int i = 0; i < input->record_format_size(); ++i) {
- if (input->record_format(i) == record_format) found = true;
- }
- if (!found) input->add_record_format(record_format);
- }
- return input;
-}
-
-void TaskContext::SetParameter(const std::string &name,
- const std::string &value) {
- TC_LOG(INFO) << "SetParameter(" << name << ", " << value << ")";
-
- // If the parameter already exists update the value.
- for (int i = 0; i < spec_.parameter_size(); ++i) {
- if (spec_.parameter(i).name() == name) {
- spec_.mutable_parameter(i)->set_value(value);
- return;
- }
- }
-
- // Add new parameter.
- TaskSpec::Parameter *param = spec_.add_parameter();
- param->set_name(name);
- param->set_value(value);
-}
-
-std::string TaskContext::GetParameter(const std::string &name) const {
- // First try to find parameter in task specification.
- for (int i = 0; i < spec_.parameter_size(); ++i) {
- if (spec_.parameter(i).name() == name) return spec_.parameter(i).value();
- }
-
- // Parameter not found, return empty std::string.
- return "";
-}
-
-int TaskContext::GetIntParameter(const std::string &name) const {
- std::string value = GetParameter(name);
- return ParseInt32WithDefault(value, 0);
-}
-
-int64 TaskContext::GetInt64Parameter(const std::string &name) const {
- std::string value = GetParameter(name);
- return ParseInt64WithDefault(value, 0);
-}
-
-bool TaskContext::GetBoolParameter(const std::string &name) const {
- std::string value = GetParameter(name);
- return value == "true";
-}
-
-double TaskContext::GetFloatParameter(const std::string &name) const {
- std::string value = GetParameter(name);
- return ParseDoubleWithDefault(value, 0.0);
-}
-
-std::string TaskContext::Get(const std::string &name,
- const char *defval) const {
- // First try to find parameter in task specification.
- for (int i = 0; i < spec_.parameter_size(); ++i) {
- if (spec_.parameter(i).name() == name) return spec_.parameter(i).value();
- }
-
- // Parameter not found, return default value.
- return defval;
-}
-
-std::string TaskContext::Get(const std::string &name,
- const std::string &defval) const {
- return Get(name, defval.c_str());
-}
-
-int TaskContext::Get(const std::string &name, int defval) const {
- std::string value = Get(name, "");
- return ParseInt32WithDefault(value, defval);
-}
-
-int64 TaskContext::Get(const std::string &name, int64 defval) const {
- std::string value = Get(name, "");
- return ParseInt64WithDefault(value, defval);
-}
-
-double TaskContext::Get(const std::string &name, double defval) const {
- std::string value = Get(name, "");
- return ParseDoubleWithDefault(value, defval);
-}
-
-bool TaskContext::Get(const std::string &name, bool defval) const {
- std::string value = Get(name, "");
- return value.empty() ? defval : value == "true";
-}
-
-std::string TaskContext::InputFile(const TaskInput &input) {
- if (input.part_size() == 0) {
- TC_LOG(ERROR) << "No file for TaskInput " << input.name();
- return "";
- }
- if (input.part_size() > 1) {
- TC_LOG(ERROR) << "Ambiguous: multiple files for TaskInput " << input.name();
- }
- return input.part(0).file_pattern();
-}
-
-bool TaskContext::Supports(const TaskInput &input,
- const std::string &file_format,
- const std::string &record_format) {
- // Check file format.
- if (input.file_format_size() > 0) {
- bool found = false;
- for (int i = 0; i < input.file_format_size(); ++i) {
- if (input.file_format(i) == file_format) {
- found = true;
- break;
- }
- }
- if (!found) return false;
- }
-
- // Check record format.
- if (input.record_format_size() > 0) {
- bool found = false;
- for (int i = 0; i < input.record_format_size(); ++i) {
- if (input.record_format(i) == record_format) {
- found = true;
- break;
- }
- }
- if (!found) return false;
- }
-
- return true;
-}
-
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/common/task-context.h b/common/task-context.h
deleted file mode 100644
index c55ed67..0000000
--- a/common/task-context.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_TASK_CONTEXT_H_
-#define LIBTEXTCLASSIFIER_COMMON_TASK_CONTEXT_H_
-
-#include <string>
-#include <vector>
-
-#include "common/task-spec.pb.h"
-#include "util/base/integral_types.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// A task context holds configuration information for a task. It is basically a
-// wrapper around a TaskSpec protocol buffer.
-class TaskContext {
- public:
- // Returns the underlying task specification protocol buffer for the context.
- const TaskSpec &spec() const { return spec_; }
- TaskSpec *mutable_spec() { return &spec_; }
-
- // Returns a named input descriptor for the task. A new input is created if
- // the task context does not already have an input with that name.
- TaskInput *GetInput(const std::string &name);
- TaskInput *GetInput(const std::string &name,
- const std::string &file_format,
- const std::string &record_format);
-
- // Sets task parameter.
- void SetParameter(const std::string &name, const std::string &value);
-
- // Returns task parameter. If the parameter is not in the task configuration
- // the (default) value of the corresponding command line flag is returned.
- std::string GetParameter(const std::string &name) const;
- int GetIntParameter(const std::string &name) const;
- int64 GetInt64Parameter(const std::string &name) const;
- bool GetBoolParameter(const std::string &name) const;
- double GetFloatParameter(const std::string &name) const;
-
- // Returns task parameter. If the parameter is not in the task configuration
- // the default value is returned.
- std::string Get(const std::string &name, const std::string &defval) const;
- std::string Get(const std::string &name, const char *defval) const;
- int Get(const std::string &name, int defval) const;
- int64 Get(const std::string &name, int64 defval) const;
- double Get(const std::string &name, double defval) const;
- bool Get(const std::string &name, bool defval) const;
-
- // Returns input file name for a single-file task input.
- //
- // Special cases: returns the empty string if the TaskInput does not have any
- // input files. Returns the first file if the TaskInput has multiple input
- // files.
- static std::string InputFile(const TaskInput &input);
-
- // Returns true if task input supports the file and record format.
- static bool Supports(const TaskInput &input, const std::string &file_format,
- const std::string &record_format);
-
- private:
- // Underlying task specification protocol buffer.
- TaskSpec spec_;
-
- // Vector of parameters required by this task. These must be specified in the
- // task rather than relying on default values.
- std::vector<std::string> required_parameters_;
-};
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_TASK_CONTEXT_H_
diff --git a/common/task-spec.proto b/common/task-spec.proto
deleted file mode 100644
index ab986ce..0000000
--- a/common/task-spec.proto
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright (C) 2017 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// LINT: ALLOW_GROUPS
-// Protocol buffer specifications for task configuration.
-
-syntax = "proto2";
-option optimize_for = LITE_RUNTIME;
-
-package libtextclassifier.nlp_core;
-
-// Task input descriptor.
-message TaskInput {
- // Name of input resource.
- required string name = 1;
-
- // File format for resource.
- repeated string file_format = 3;
-
- // Record format for resource.
- repeated string record_format = 4;
-
- // An input can consist of multiple file sets.
- repeated group Part = 6 {
- // File pattern for file set.
- optional string file_pattern = 7;
-
- // File format for file set.
- optional string file_format = 8;
-
- // Record format for file set.
- optional string record_format = 9;
- }
-
- reserved 2, 5;
-}
-
-// A task specification is used for describing executing parameters.
-message TaskSpec {
- // Task parameters.
- repeated group Parameter = 3 {
- required string name = 4;
- optional string value = 5;
- }
-
- // Task inputs.
- repeated TaskInput input = 6;
-
- reserved 1, 2, 7;
-}
diff --git a/common/vector-span.h b/common/vector-span.h
deleted file mode 100644
index d7fbfe9..0000000
--- a/common/vector-span.h
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_VECTOR_SPAN_H_
-#define LIBTEXTCLASSIFIER_COMMON_VECTOR_SPAN_H_
-
-#include <vector>
-
-namespace libtextclassifier {
-
-// StringPiece analogue for std::vector<T>.
-template <class T>
-class VectorSpan {
- public:
- VectorSpan() : begin_(), end_() {}
- VectorSpan(const std::vector<T>& v) // NOLINT(runtime/explicit)
- : begin_(v.begin()), end_(v.end()) {}
- VectorSpan(typename std::vector<T>::const_iterator begin,
- typename std::vector<T>::const_iterator end)
- : begin_(begin), end_(end) {}
-
- const T& operator[](typename std::vector<T>::size_type i) const {
- return *(begin_ + i);
- }
-
- int size() const { return end_ - begin_; }
- typename std::vector<T>::const_iterator begin() const { return begin_; }
- typename std::vector<T>::const_iterator end() const { return end_; }
-
- private:
- typename std::vector<T>::const_iterator begin_;
- typename std::vector<T>::const_iterator end_;
-};
-
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_VECTOR_SPAN_H_
diff --git a/common/workspace.cc b/common/workspace.cc
deleted file mode 100644
index 770e4be..0000000
--- a/common/workspace.cc
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common/workspace.h"
-
-#include <atomic>
-#include <string>
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// static
-int GetFreshTypeId() {
- // Static local below is initialized the first time this method is run.
- static std::atomic<int> counter(0);
- return counter++;
-}
-
-std::string WorkspaceRegistry::DebugString() const {
- std::string str;
- for (auto &it : workspace_names_) {
- const std::string &type_name = workspace_types_.at(it.first);
- for (size_t index = 0; index < it.second.size(); ++index) {
- const std::string &workspace_name = it.second[index];
- str.append("\n ");
- str.append(type_name);
- str.append(" :: ");
- str.append(workspace_name);
- }
- }
- return str;
-}
-
-VectorIntWorkspace::VectorIntWorkspace(int size) : elements_(size) {}
-
-VectorIntWorkspace::VectorIntWorkspace(int size, int value)
- : elements_(size, value) {}
-
-VectorIntWorkspace::VectorIntWorkspace(const std::vector<int> &elements)
- : elements_(elements) {}
-
-std::string VectorIntWorkspace::TypeName() { return "Vector"; }
-
-VectorVectorIntWorkspace::VectorVectorIntWorkspace(int size)
- : elements_(size) {}
-
-std::string VectorVectorIntWorkspace::TypeName() { return "VectorVector"; }
-
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/common/workspace.h b/common/workspace.h
deleted file mode 100644
index e003bde..0000000
--- a/common/workspace.h
+++ /dev/null
@@ -1,245 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Notes on thread-safety: All of the classes here are thread-compatible. More
-// specifically, the registry machinery is thread-safe, as long as each thread
-// performs feature extraction on a different Sentence object.
-
-#ifndef LIBTEXTCLASSIFIER_COMMON_WORKSPACE_H_
-#define LIBTEXTCLASSIFIER_COMMON_WORKSPACE_H_
-
-#include <stddef.h>
-#include <string>
-#include <unordered_map>
-#include <utility>
-#include <vector>
-
-#include "util/base/logging.h"
-#include "util/base/macros.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// A base class for shared workspaces. Derived classes implement a static member
-// function TypeName() which returns a human readable std::string name for the
-// class.
-class Workspace {
- public:
- // Polymorphic destructor.
- virtual ~Workspace() {}
-
- protected:
- // Create an empty workspace.
- Workspace() {}
-
- private:
- TC_DISALLOW_COPY_AND_ASSIGN(Workspace);
-};
-
-// Returns a new, strictly increasing int every time it is invoked.
-int GetFreshTypeId();
-
-// Struct to simulate typeid, but without RTTI.
-template <typename T>
-struct TypeId {
- static int type_id;
-};
-
-template <typename T>
-int TypeId<T>::type_id = GetFreshTypeId();
-
-// A registry that keeps track of workspaces.
-class WorkspaceRegistry {
- public:
- // Create an empty registry.
- WorkspaceRegistry() {}
-
- // Returns the index of a named workspace, adding it to the registry first
- // if necessary.
- template <class W>
- int Request(const std::string &name) {
- const int id = TypeId<W>::type_id;
- max_workspace_id_ = std::max(id, max_workspace_id_);
- workspace_types_[id] = W::TypeName();
- std::vector<std::string> &names = workspace_names_[id];
- for (int i = 0; i < names.size(); ++i) {
- if (names[i] == name) return i;
- }
- names.push_back(name);
- return names.size() - 1;
- }
-
- // Returns the maximum workspace id that has been registered.
- int MaxId() const {
- return max_workspace_id_;
- }
-
- const std::unordered_map<int, std::vector<std::string> > &WorkspaceNames()
- const {
- return workspace_names_;
- }
-
- // Returns a std::string describing the registered workspaces.
- std::string DebugString() const;
-
- private:
- // Workspace type names, indexed as workspace_types_[typeid].
- std::unordered_map<int, std::string> workspace_types_;
-
- // Workspace names, indexed as workspace_names_[typeid][workspace].
- std::unordered_map<int, std::vector<std::string> > workspace_names_;
-
- // The maximum workspace id that has been registered.
- int max_workspace_id_ = 0;
-
- TC_DISALLOW_COPY_AND_ASSIGN(WorkspaceRegistry);
-};
-
-// A typed collected of workspaces. The workspaces are indexed according to an
-// external WorkspaceRegistry. If the WorkspaceSet is const, the contents are
-// also immutable.
-class WorkspaceSet {
- public:
- ~WorkspaceSet() { Reset(WorkspaceRegistry()); }
-
- // Returns true if a workspace has been set.
- template <class W>
- bool Has(int index) const {
- const int id = TypeId<W>::type_id;
- TC_DCHECK_GE(id, 0);
- TC_DCHECK_LT(id, workspaces_.size());
- TC_DCHECK_GE(index, 0);
- TC_DCHECK_LT(index, workspaces_[id].size());
- if (id >= workspaces_.size()) return false;
- return workspaces_[id][index] != nullptr;
- }
-
- // Returns an indexed workspace; the workspace must have been set.
- template <class W>
- const W &Get(int index) const {
- TC_DCHECK(Has<W>(index));
- const int id = TypeId<W>::type_id;
- const Workspace *w = workspaces_[id][index];
- return reinterpret_cast<const W &>(*w);
- }
-
- // Sets an indexed workspace; this takes ownership of the workspace, which
- // must have been new-allocated. It is an error to set a workspace twice.
- template <class W>
- void Set(int index, W *workspace) {
- const int id = TypeId<W>::type_id;
- TC_DCHECK_GE(id, 0);
- TC_DCHECK_LT(id, workspaces_.size());
- TC_DCHECK_GE(index, 0);
- TC_DCHECK_LT(index, workspaces_[id].size());
- TC_DCHECK(workspaces_[id][index] == nullptr);
- TC_DCHECK(workspace != nullptr);
- workspaces_[id][index] = workspace;
- }
-
- void Reset(const WorkspaceRegistry ®istry) {
- // Deallocate current workspaces.
- for (auto &it : workspaces_) {
- for (size_t index = 0; index < it.size(); ++index) {
- delete it[index];
- }
- }
- workspaces_.clear();
- workspaces_.resize(registry.MaxId() + 1, std::vector<Workspace *>());
- for (auto &it : registry.WorkspaceNames()) {
- workspaces_[it.first].resize(it.second.size());
- }
- }
-
- private:
- // The set of workspaces, indexed as workspaces_[typeid][index].
- std::vector<std::vector<Workspace *> > workspaces_;
-};
-
-// A workspace that wraps around a single int.
-class SingletonIntWorkspace : public Workspace {
- public:
- // Default-initializes the int value.
- SingletonIntWorkspace() {}
-
- // Initializes the int with the given value.
- explicit SingletonIntWorkspace(int value) : value_(value) {}
-
- // Returns the name of this type of workspace.
- static std::string TypeName() { return "SingletonInt"; }
-
- // Returns the int value.
- int get() const { return value_; }
-
- // Sets the int value.
- void set(int value) { value_ = value; }
-
- private:
- // The enclosed int.
- int value_ = 0;
-};
-
-// A workspace that wraps around a vector of int.
-class VectorIntWorkspace : public Workspace {
- public:
- // Creates a vector of the given size.
- explicit VectorIntWorkspace(int size);
-
- // Creates a vector initialized with the given array.
- explicit VectorIntWorkspace(const std::vector<int> &elements);
-
- // Creates a vector of the given size, with each element initialized to the
- // given value.
- VectorIntWorkspace(int size, int value);
-
- // Returns the name of this type of workspace.
- static std::string TypeName();
-
- // Returns the i'th element.
- int element(int i) const { return elements_[i]; }
-
- // Sets the i'th element.
- void set_element(int i, int value) { elements_[i] = value; }
-
- private:
- // The enclosed vector.
- std::vector<int> elements_;
-};
-
-// A workspace that wraps around a vector of vector of int.
-class VectorVectorIntWorkspace : public Workspace {
- public:
- // Creates a vector of empty vectors of the given size.
- explicit VectorVectorIntWorkspace(int size);
-
- // Returns the name of this type of workspace.
- static std::string TypeName();
-
- // Returns the i'th vector of elements.
- const std::vector<int> &elements(int i) const { return elements_[i]; }
-
- // Mutable access to the i'th vector of elements.
- std::vector<int> *mutable_elements(int i) { return &(elements_[i]); }
-
- private:
- // The enclosed vector of vector of elements.
- std::vector<std::vector<int> > elements_;
-};
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_COMMON_WORKSPACE_H_
diff --git a/smartselect/feature-processor.cc b/feature-processor.cc
similarity index 76%
rename from smartselect/feature-processor.cc
rename to feature-processor.cc
index c1db95a..c607b13 100644
--- a/smartselect/feature-processor.cc
+++ b/feature-processor.cc
@@ -14,59 +14,51 @@
* limitations under the License.
*/
-#include "smartselect/feature-processor.h"
+#include "feature-processor.h"
#include <iterator>
#include <set>
#include <vector>
-#include "smartselect/text-classification-model.pb.h"
#include "util/base/logging.h"
#include "util/strings/utf8.h"
#include "util/utf8/unicodetext.h"
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
-#include "unicode/brkiter.h"
-#include "unicode/errorcode.h"
-#include "unicode/uchar.h"
-#endif
-namespace libtextclassifier {
+namespace libtextclassifier2 {
namespace internal {
TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
- const FeatureProcessorOptions& options) {
+ const FeatureProcessorOptions* const options) {
TokenFeatureExtractorOptions extractor_options;
- extractor_options.num_buckets = options.num_buckets();
- for (int order : options.chargram_orders()) {
- extractor_options.chargram_orders.push_back(order);
+ extractor_options.num_buckets = options->num_buckets();
+ if (options->chargram_orders() != nullptr) {
+ for (int order : *options->chargram_orders()) {
+ extractor_options.chargram_orders.push_back(order);
+ }
}
- extractor_options.max_word_length = options.max_word_length();
- extractor_options.extract_case_feature = options.extract_case_feature();
- extractor_options.unicode_aware_features = options.unicode_aware_features();
+ extractor_options.max_word_length = options->max_word_length();
+ extractor_options.extract_case_feature = options->extract_case_feature();
+ extractor_options.unicode_aware_features = options->unicode_aware_features();
extractor_options.extract_selection_mask_feature =
- options.extract_selection_mask_feature();
- for (int i = 0; i < options.regexp_feature_size(); ++i) {
- extractor_options.regexp_features.push_back(options.regexp_feature(i));
+ options->extract_selection_mask_feature();
+ if (options->regexp_feature() != nullptr) {
+ for (const auto& regexp_feauture : *options->regexp_feature()) {
+ extractor_options.regexp_features.push_back(regexp_feauture->str());
+ }
}
- extractor_options.remap_digits = options.remap_digits();
- extractor_options.lowercase_tokens = options.lowercase_tokens();
+ extractor_options.remap_digits = options->remap_digits();
+ extractor_options.lowercase_tokens = options->lowercase_tokens();
- for (const auto& chargram : options.allowed_chargrams()) {
- extractor_options.allowed_chargrams.insert(chargram);
+ if (options->allowed_chargrams() != nullptr) {
+ for (const auto& chargram : *options->allowed_chargrams()) {
+ extractor_options.allowed_chargrams.insert(chargram->str());
+ }
}
-
return extractor_options;
}
-FeatureProcessorOptions ParseSerializedOptions(
- const std::string& serialized_options) {
- FeatureProcessorOptions options;
- options.ParseFromString(serialized_options);
- return options;
-}
-
void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
std::vector<Token>* tokens) {
for (auto it = tokens->begin(); it != tokens->end(); ++it) {
@@ -119,6 +111,16 @@
}
}
+UniLib* MaybeCreateUnilib(UniLib* unilib,
+ std::unique_ptr<UniLib>* owned_unilib) {
+ if (unilib) {
+ return unilib;
+ } else {
+ owned_unilib->reset(new UniLib);
+ return owned_unilib->get();
+ }
+}
+
} // namespace internal
void FeatureProcessor::StripTokensFromOtherLines(
@@ -157,30 +159,30 @@
}
std::string FeatureProcessor::GetDefaultCollection() const {
- if (options_.default_collection() < 0 ||
- options_.default_collection() >= options_.collections_size()) {
+ if (options_->default_collection() < 0 ||
+ options_->default_collection() >= options_->collections()->size()) {
TC_LOG(ERROR)
<< "Invalid or missing default collection. Returning empty string.";
return "";
}
- return options_.collections(options_.default_collection());
+ return (*options_->collections())[options_->default_collection()]->str();
}
std::vector<Token> FeatureProcessor::Tokenize(
const std::string& utf8_text) const {
- if (options_.tokenization_type() ==
- libtextclassifier::FeatureProcessorOptions::INTERNAL_TOKENIZER) {
+ if (options_->tokenization_type() ==
+ FeatureProcessorOptions_::TokenizationType_INTERNAL_TOKENIZER) {
return tokenizer_.Tokenize(utf8_text);
- } else if (options_.tokenization_type() ==
- libtextclassifier::FeatureProcessorOptions::ICU ||
- options_.tokenization_type() ==
- libtextclassifier::FeatureProcessorOptions::MIXED) {
+ } else if (options_->tokenization_type() ==
+ FeatureProcessorOptions_::TokenizationType_ICU ||
+ options_->tokenization_type() ==
+ FeatureProcessorOptions_::TokenizationType_MIXED) {
std::vector<Token> result;
if (!ICUTokenize(utf8_text, &result)) {
return {};
}
- if (options_.tokenization_type() ==
- libtextclassifier::FeatureProcessorOptions::MIXED) {
+ if (options_->tokenization_type() ==
+ FeatureProcessorOptions_::TokenizationType_MIXED) {
InternalRetokenize(utf8_text, &result);
}
return result;
@@ -205,11 +207,11 @@
const int result_begin_token_index = token_span.first;
const Token& result_begin_token =
- tokens[options_.context_size() - result_begin_token_index];
+ tokens[options_->context_size() - result_begin_token_index];
const int result_begin_codepoint = result_begin_token.start;
const int result_end_token_index = token_span.second;
const Token& result_end_token =
- tokens[options_.context_size() + result_end_token_index];
+ tokens[options_->context_size() + result_end_token_index];
const int result_end_codepoint = result_end_token.end;
if (result_begin_codepoint == kInvalidIndex ||
@@ -224,9 +226,11 @@
UnicodeText::const_iterator token_end = token_end_unicode.end();
const int begin_ignored = CountIgnoredSpanBoundaryCodepoints(
- token_begin, token_begin_unicode.end(), /*count_from_beginning=*/true);
- const int end_ignored = CountIgnoredSpanBoundaryCodepoints(
- token_end_unicode.begin(), token_end, /*count_from_beginning=*/false);
+ token_begin, token_begin_unicode.end(),
+ /*count_from_beginning=*/true);
+ const int end_ignored =
+ CountIgnoredSpanBoundaryCodepoints(token_end_unicode.begin(), token_end,
+ /*count_from_beginning=*/false);
// In case everything would be stripped, set the span to the original
// beginning and zero length.
if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) {
@@ -257,8 +261,8 @@
}
const int click_position =
- options_.context_size(); // Click is always in the middle.
- const int padding = options_.context_size() - options_.max_selection_span();
+ options_->context_size(); // Click is always in the middle.
+ const int padding = options_->context_size() - options_->max_selection_span();
int span_left = 0;
for (int i = click_position - 1; i >= padding; i--) {
@@ -282,7 +286,7 @@
bool tokens_match_span;
const CodepointIndex tokens_start = tokens[click_position - span_left].start;
const CodepointIndex tokens_end = tokens[click_position + span_right].end;
- if (options_.snap_label_span_boundaries_to_containing_tokens()) {
+ if (options_->snap_label_span_boundaries_to_containing_tokens()) {
tokens_match_span = tokens_start <= span.first && tokens_end >= span.second;
} else {
const UnicodeText token_left_unicode = UTF8ToUnicodeText(
@@ -296,7 +300,8 @@
const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints(
span_begin, token_left_unicode.end(), /*count_from_beginning=*/true);
const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints(
- token_right_unicode.begin(), span_end, /*count_from_beginning=*/false);
+ token_right_unicode.begin(), span_end,
+ /*count_from_beginning=*/false);
tokens_match_span = tokens_start <= span.first &&
tokens_start + num_punctuation_start >= span.first &&
@@ -422,19 +427,22 @@
int FeatureProcessor::FindCenterToken(CodepointSpan span,
const std::vector<Token>& tokens) const {
- if (options_.center_token_selection_method() ==
- FeatureProcessorOptions::CENTER_TOKEN_FROM_CLICK) {
+ if (options_->center_token_selection_method() ==
+ FeatureProcessorOptions_::
+ CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK) {
return internal::CenterTokenFromClick(span, tokens);
- } else if (options_.center_token_selection_method() ==
- FeatureProcessorOptions::CENTER_TOKEN_MIDDLE_OF_SELECTION) {
+ } else if (options_->center_token_selection_method() ==
+ FeatureProcessorOptions_::
+ CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION) {
return internal::CenterTokenFromMiddleOfSelection(span, tokens);
- } else if (options_.center_token_selection_method() ==
- FeatureProcessorOptions::DEFAULT_CENTER_TOKEN_METHOD) {
+ } else if (options_->center_token_selection_method() ==
+ FeatureProcessorOptions_::
+ CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD) {
// TODO(zilka): Remove once we have new models on the device.
// It uses the fact that sharing model use
// split_tokens_on_selection_boundaries and selection not. So depending on
// this we select the right way of finding the click location.
- if (!options_.split_tokens_on_selection_boundaries()) {
+ if (!options_->split_tokens_on_selection_boundaries()) {
// SmartSelection model.
return internal::CenterTokenFromClick(span, tokens);
} else {
@@ -462,15 +470,15 @@
}
void FeatureProcessor::PrepareCodepointRanges(
- const std::vector<FeatureProcessorOptions::CodepointRange>&
+ const std::vector<const FeatureProcessorOptions_::CodepointRange*>&
codepoint_ranges,
std::vector<CodepointRange>* prepared_codepoint_ranges) {
prepared_codepoint_ranges->clear();
prepared_codepoint_ranges->reserve(codepoint_ranges.size());
- for (const FeatureProcessorOptions::CodepointRange& range :
+ for (const FeatureProcessorOptions_::CodepointRange* range :
codepoint_ranges) {
prepared_codepoint_ranges->push_back(
- CodepointRange(range.start(), range.end()));
+ CodepointRange(range->start(), range->end()));
}
std::sort(prepared_codepoint_ranges->begin(),
@@ -481,8 +489,10 @@
}
void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
- for (const int codepoint : options_.ignored_span_boundary_codepoints()) {
- ignored_span_boundary_codepoints_.insert(codepoint);
+ if (options_->ignored_span_boundary_codepoints() != nullptr) {
+ for (const int codepoint : *options_->ignored_span_boundary_codepoints()) {
+ ignored_span_boundary_codepoints_.insert(codepoint);
+ }
}
}
@@ -555,7 +565,7 @@
std::vector<UnicodeTextRange> FeatureProcessor::SplitContext(
const UnicodeText& context_unicode) const {
- if (options_.only_use_line_with_click()) {
+ if (options_->only_use_line_with_click()) {
std::vector<UnicodeTextRange> lines;
std::set<char32> codepoints;
codepoints.insert('\n');
@@ -589,21 +599,17 @@
}
float FeatureProcessor::SupportedCodepointsRatio(
- int click_pos, const std::vector<Token>& tokens) const {
+ const TokenSpan& token_span, const std::vector<Token>& tokens) const {
int num_supported = 0;
int num_total = 0;
- for (int i = click_pos - options_.context_size();
- i <= click_pos + options_.context_size(); ++i) {
- const bool is_valid_token = i >= 0 && i < tokens.size();
- if (is_valid_token) {
- const UnicodeText value =
- UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false);
- for (auto codepoint : value) {
- if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) {
- ++num_supported;
- }
- ++num_total;
+ for (int i = token_span.first; i < token_span.second; ++i) {
+ const UnicodeText value =
+ UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false);
+ for (auto codepoint : value) {
+ if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) {
+ ++num_supported;
}
+ ++num_total;
}
}
return static_cast<float>(num_supported) / static_cast<float>(num_total);
@@ -640,7 +646,7 @@
int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
const auto it = collection_to_label_.find(collection);
if (it == collection_to_label_.end()) {
- return options_.default_collection();
+ return options_->default_collection();
} else {
return it->second;
}
@@ -648,22 +654,24 @@
std::string FeatureProcessor::LabelToCollection(int label) const {
if (label >= 0 && label < collection_to_label_.size()) {
- return options_.collections(label);
+ return (*options_->collections())[label]->str();
} else {
return GetDefaultCollection();
}
}
void FeatureProcessor::MakeLabelMaps() {
- for (int i = 0; i < options_.collections().size(); ++i) {
- collection_to_label_[options_.collections(i)] = i;
+ if (options_->collections() != nullptr) {
+ for (int i = 0; i < options_->collections()->size(); ++i) {
+ collection_to_label_[(*options_->collections())[i]->str()] = i;
+ }
}
int selection_label_id = 0;
- for (int l = 0; l < (options_.max_selection_span() + 1); ++l) {
- for (int r = 0; r < (options_.max_selection_span() + 1); ++r) {
- if (!options_.selection_reduced_output_space() ||
- r + l <= options_.max_selection_span()) {
+ for (int l = 0; l < (options_->max_selection_span() + 1); ++l) {
+ for (int r = 0; r < (options_->max_selection_span() + 1); ++r) {
+ if (!options_->selection_reduced_output_space() ||
+ r + l <= options_->max_selection_span()) {
TokenSpan token_span{l, r};
selection_to_label_[token_span] = selection_label_id;
label_to_selection_.push_back(token_span);
@@ -680,11 +688,11 @@
TC_CHECK(tokens != nullptr);
*tokens = Tokenize(context);
- if (options_.split_tokens_on_selection_boundaries()) {
+ if (options_->split_tokens_on_selection_boundaries()) {
internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
}
- if (options_.only_use_line_with_click()) {
+ if (options_->only_use_line_with_click()) {
StripTokensFromOtherLines(context, input_span, tokens);
}
@@ -693,6 +701,11 @@
click_pos = &local_click_pos;
}
*click_pos = FindCenterToken(input_span, *tokens);
+ if (*click_pos == kInvalidIndex) {
+ // If the default click method failed, let's try to do sub-token matching
+ // before we fail.
+ *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
+ }
}
namespace internal {
@@ -734,118 +747,102 @@
} // namespace internal
bool FeatureProcessor::ExtractFeatures(
- const std::string& context, CodepointSpan input_span,
- TokenSpan relative_click_span, const FeatureVectorFn& feature_vector_fn,
- int feature_vector_size, std::vector<Token>* tokens, int* click_pos,
+ const std::vector<Token>& tokens, TokenSpan token_span,
+ EmbeddingExecutor* embedding_executor, int feature_vector_size,
std::unique_ptr<CachedFeatures>* cached_features) const {
- TokenizeAndFindClick(context, input_span, tokens, click_pos);
-
- if (input_span.first != kInvalidIndex && input_span.second != kInvalidIndex) {
- // If the default click method failed, let's try to do sub-token matching
- // before we fail.
- if (*click_pos == kInvalidIndex) {
- *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
- if (*click_pos == kInvalidIndex) {
- return false;
- }
- }
- } else {
- // If input_span is unspecified, click the first token and extract features
- // from all tokens.
- *click_pos = 0;
- relative_click_span = {0, tokens->size()};
+ if (options_->feature_version() < 2) {
+ TC_LOG(ERROR) << "Unsupported feature version.";
+ return false;
+ }
+ if (!options_->bounds_sensitive_features() ||
+ !options_->bounds_sensitive_features()->enabled()) {
+ TC_LOG(ERROR) << "Bounds-sensitive features not enabled.";
+ return false;
}
- internal::StripOrPadTokens(relative_click_span, options_.context_size(),
- tokens, click_pos);
-
- if (options_.min_supported_codepoint_ratio() > 0) {
+ if (options_->min_supported_codepoint_ratio() > 0) {
const float supported_codepoint_ratio =
- SupportedCodepointsRatio(*click_pos, *tokens);
- if (supported_codepoint_ratio < options_.min_supported_codepoint_ratio()) {
+ SupportedCodepointsRatio(token_span, tokens);
+ if (supported_codepoint_ratio < options_->min_supported_codepoint_ratio()) {
TC_VLOG(1) << "Not enough supported codepoints in the context: "
<< supported_codepoint_ratio;
return false;
}
}
- std::vector<std::vector<int>> sparse_features(tokens->size());
- std::vector<std::vector<float>> dense_features(tokens->size());
- for (int i = 0; i < tokens->size(); ++i) {
- const Token& token = (*tokens)[i];
- if (!feature_extractor_.Extract(token, token.IsContainedInSpan(input_span),
- &(sparse_features[i]),
- &(dense_features[i]))) {
+ std::vector<std::vector<int>> sparse_features(TokenSpanSize(token_span));
+ std::vector<std::vector<float>> dense_features(TokenSpanSize(token_span));
+ for (int i = token_span.first; i < token_span.second; ++i) {
+ const Token& token = tokens[i];
+ const int features_index = i - token_span.first;
+ if (!feature_extractor_.Extract(token, false,
+ &(sparse_features[features_index]),
+ &(dense_features[features_index]))) {
TC_LOG(ERROR) << "Could not extract token's features: " << token;
return false;
}
}
- cached_features->reset(new CachedFeatures(
- *tokens, options_.context_size(), sparse_features, dense_features,
- feature_vector_fn, feature_vector_size));
-
- if (*cached_features == nullptr) {
+ std::vector<int> padding_sparse_features;
+ std::vector<float> padding_dense_features;
+ if (!feature_extractor_.Extract(Token(), false, &padding_sparse_features,
+ &padding_dense_features)) {
+ TC_LOG(ERROR) << "Could not extract padding token's features.";
return false;
}
- if (options_.feature_version() == 0) {
- (*cached_features)
- ->SetV0FeatureMode(feature_vector_size -
- feature_extractor_.DenseFeaturesCount());
- }
+ cached_features->reset(new CachedFeatures(
+ token_span, sparse_features, dense_features, padding_sparse_features,
+ padding_dense_features, options_->bounds_sensitive_features(),
+ embedding_executor, feature_vector_size));
return true;
}
bool FeatureProcessor::ICUTokenize(const std::string& context,
std::vector<Token>* result) const {
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
- icu::ErrorCode status;
- icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(context);
- std::unique_ptr<icu::BreakIterator> break_iterator(
- icu::BreakIterator::createWordInstance(icu::Locale("en"), status));
- if (!status.isSuccess()) {
- TC_LOG(ERROR) << "Break iterator did not initialize properly: "
- << status.errorName();
+ std::unique_ptr<UniLib::BreakIterator> break_iterator =
+ unilib_->CreateBreakIterator(context);
+ if (!break_iterator) {
return false;
}
- break_iterator->setText(unicode_text);
-
- size_t last_break_index = 0;
- size_t break_index = 0;
- size_t last_unicode_index = 0;
- size_t unicode_index = 0;
- while ((break_index = break_iterator->next()) != icu::BreakIterator::DONE) {
- icu::UnicodeString token(unicode_text, last_break_index,
- break_index - last_break_index);
- int token_length = token.countChar32();
+ UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false);
+ int last_break_index = 0;
+ int break_index = 0;
+ int last_unicode_index = 0;
+ int unicode_index = 0;
+ auto token_begin_it = context_unicode.begin();
+ while ((break_index = break_iterator->Next()) !=
+ UniLib::BreakIterator::kDone) {
+ const int token_length = break_index - last_break_index;
unicode_index = last_unicode_index + token_length;
- std::string token_utf8;
- token.toUTF8String(token_utf8);
+ auto token_end_it = token_begin_it;
+ std::advance(token_end_it, token_length);
+ // Determine if the whole token is whitespace.
bool is_whitespace = true;
- for (int i = 0; i < token.length(); i++) {
- if (!u_isWhitespace(token.char32At(i))) {
+ for (auto char_it = token_begin_it; char_it < token_end_it; ++char_it) {
+ if (!unilib_->IsWhitespace(*char_it)) {
is_whitespace = false;
+ break;
}
}
- if (!is_whitespace || options_.icu_preserve_whitespace_tokens()) {
- result->push_back(Token(token_utf8, last_unicode_index, unicode_index));
+ const std::string token =
+ context_unicode.UTF8Substring(token_begin_it, token_end_it);
+
+ if (!is_whitespace || options_->icu_preserve_whitespace_tokens()) {
+ result->push_back(Token(token, last_unicode_index, unicode_index));
}
last_break_index = break_index;
last_unicode_index = unicode_index;
+ token_begin_it = token_end_it;
}
return true;
-#else
- TC_LOG(WARNING) << "Can't tokenize, ICU not supported";
- return false;
-#endif
}
void FeatureProcessor::InternalRetokenize(const std::string& context,
@@ -914,4 +911,4 @@
}
}
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
diff --git a/smartselect/feature-processor.h b/feature-processor.h
similarity index 76%
rename from smartselect/feature-processor.h
rename to feature-processor.h
index ef9a3df..834c260 100644
--- a/smartselect/feature-processor.h
+++ b/feature-processor.h
@@ -16,42 +16,33 @@
// Feature processing for FFModel (feed-forward SmartSelection model).
-#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
-#define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_FEATURE_PROCESSOR_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_FEATURE_PROCESSOR_H_
+#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
-#include "smartselect/cached-features.h"
-#include "smartselect/text-classification-model.pb.h"
-#include "smartselect/token-feature-extractor.h"
-#include "smartselect/tokenizer.h"
-#include "smartselect/types.h"
+#include "cached-features.h"
+#include "model_generated.h"
+#include "token-feature-extractor.h"
+#include "tokenizer.h"
+#include "types.h"
+#include "util/base/integral_types.h"
#include "util/base/logging.h"
#include "util/utf8/unicodetext.h"
+#include "util/utf8/unilib.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
constexpr int kInvalidLabel = -1;
-// Maps a vector of sparse features and a vector of dense features to a vector
-// of features that combines both.
-// The output is written to the memory location pointed to by the last float*
-// argument.
-// Returns true on success false on failure.
-using FeatureVectorFn = std::function<bool(const std::vector<int>&,
- const std::vector<float>&, float*)>;
-
namespace internal {
-// Parses the serialized protocol buffer.
-FeatureProcessorOptions ParseSerializedOptions(
- const std::string& serialized_options);
-
TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
- const FeatureProcessorOptions& options);
+ const FeatureProcessorOptions* options);
// Splits tokens that contain the selection boundary inside them.
// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
@@ -73,6 +64,11 @@
void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
std::vector<Token>* tokens, int* click_pos);
+// If unilib is not nullptr, just returns unilib. Otherwise, if unilib is
+// nullptr, will create UniLib, assign ownership to owned_unilib, and return it.
+UniLib* MaybeCreateUnilib(UniLib* unilib,
+ std::unique_ptr<UniLib>* owned_unilib);
+
} // namespace internal
// Converts a codepoint span to a token span in the given list of tokens.
@@ -90,27 +86,36 @@
// Takes care of preparing features for the span prediction model.
class FeatureProcessor {
public:
- explicit FeatureProcessor(const FeatureProcessorOptions& options)
- : feature_extractor_(
- internal::BuildTokenFeatureExtractorOptions(options)),
+ // If unilib is nullptr, will create and own an instance of a UniLib,
+ // otherwise will use what's passed in.
+ explicit FeatureProcessor(const FeatureProcessorOptions* options,
+ UniLib* unilib = nullptr)
+ : owned_unilib_(nullptr),
+ unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)),
+ feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
+ *unilib_),
options_(options),
- tokenizer_({options.tokenization_codepoint_config().begin(),
- options.tokenization_codepoint_config().end()}) {
+ tokenizer_(
+ options->tokenization_codepoint_config() != nullptr
+ ? Tokenizer({options->tokenization_codepoint_config()->begin(),
+ options->tokenization_codepoint_config()->end()},
+ options->tokenize_on_script_change())
+ : Tokenizer({}, /*split_on_script_change=*/false)) {
MakeLabelMaps();
- PrepareCodepointRanges({options.supported_codepoint_ranges().begin(),
- options.supported_codepoint_ranges().end()},
- &supported_codepoint_ranges_);
- PrepareCodepointRanges(
- {options.internal_tokenizer_codepoint_ranges().begin(),
- options.internal_tokenizer_codepoint_ranges().end()},
- &internal_tokenizer_codepoint_ranges_);
+ if (options->supported_codepoint_ranges() != nullptr) {
+ PrepareCodepointRanges({options->supported_codepoint_ranges()->begin(),
+ options->supported_codepoint_ranges()->end()},
+ &supported_codepoint_ranges_);
+ }
+ if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
+ PrepareCodepointRanges(
+ {options->internal_tokenizer_codepoint_ranges()->begin(),
+ options->internal_tokenizer_codepoint_ranges()->end()},
+ &internal_tokenizer_codepoint_ranges_);
+ }
PrepareIgnoredSpanBoundaryCodepoints();
}
- explicit FeatureProcessor(const std::string& serialized_options)
- : FeatureProcessor(internal::ParseSerializedOptions(serialized_options)) {
- }
-
// Tokenizes the input string using the selected tokenization method.
std::vector<Token> Tokenize(const std::string& utf8_text) const;
@@ -129,7 +134,7 @@
// Gets the name of the default collection.
std::string GetDefaultCollection() const;
- const FeatureProcessorOptions& GetOptions() const { return options_; }
+ const FeatureProcessorOptions* GetOptions() const { return options_; }
// Tokenizes the context and input span, and finds the click position.
void TokenizeAndFindClick(const std::string& context,
@@ -138,13 +143,9 @@
// Extracts features as a CachedFeatures object that can be used for repeated
// inference over token spans in the given context.
- // When input_span == {kInvalidIndex, kInvalidIndex} then, relative_click_span
- // is ignored, and all tokens extracted from context will be considered.
- bool ExtractFeatures(const std::string& context, CodepointSpan input_span,
- TokenSpan relative_click_span,
- const FeatureVectorFn& feature_vector_fn,
- int feature_vector_size, std::vector<Token>* tokens,
- int* click_pos,
+ bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
+ EmbeddingExecutor* embedding_executor,
+ int feature_vector_size,
std::unique_ptr<CachedFeatures>* cached_features) const;
// Fills selection_label_spans with CodepointSpans that correspond to the
@@ -158,6 +159,8 @@
return feature_extractor_.DenseFeaturesCount();
}
+ int EmbeddingSize() const { return options_->embedding_size(); }
+
// Splits context to several segments according to configuration.
std::vector<UnicodeTextRange> SplitContext(
const UnicodeText& context_unicode) const;
@@ -191,7 +194,7 @@
// Spannable tokens are those tokens of context, which the model predicts
// selection spans over (i.e., there is 1:1 correspondence between the output
// classes of the model and each of the spannable tokens).
- int GetNumContextTokens() const { return options_.context_size() * 2 + 1; }
+ int GetNumContextTokens() const { return options_->context_size() * 2 + 1; }
// Converts a label into a span of codepoint indices corresponding to it
// given output_tokens.
@@ -206,13 +209,13 @@
int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
void PrepareCodepointRanges(
- const std::vector<FeatureProcessorOptions::CodepointRange>&
+ const std::vector<const FeatureProcessorOptions_::CodepointRange*>&
codepoint_ranges,
std::vector<CodepointRange>* prepared_codepoint_ranges);
// Returns the ratio of supported codepoints to total number of codepoints in
- // the input context around given click position.
- float SupportedCodepointsRatio(int click_pos,
+ // the given token span.
+ float SupportedCodepointsRatio(const TokenSpan& token_span,
const std::vector<Token>& tokens) const;
// Returns true if given codepoint is covered by the given sorted vector of
@@ -257,6 +260,11 @@
void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
std::vector<Token>* tokens) const;
+ private:
+ std::unique_ptr<UniLib> owned_unilib_;
+ UniLib* unilib_;
+
+ protected:
const TokenFeatureExtractor feature_extractor_;
// Codepoint ranges that define what codepoints are supported by the model.
@@ -274,7 +282,7 @@
// predicted spans.
std::set<int32> ignored_span_boundary_codepoints_;
- const FeatureProcessorOptions options_;
+ const FeatureProcessorOptions* const options_;
// Mapping between token selection spans and labels ids.
std::map<TokenSpan, int> selection_to_label_;
@@ -286,6 +294,6 @@
Tokenizer tokenizer_;
};
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_FEATURE_PROCESSOR_H_
diff --git a/smartselect/feature-processor_test.cc b/feature-processor_test.cc
similarity index 69%
rename from smartselect/feature-processor_test.cc
rename to feature-processor_test.cc
index 9bee67a..5af8b96 100644
--- a/smartselect/feature-processor_test.cc
+++ b/feature-processor_test.cc
@@ -14,17 +14,27 @@
* limitations under the License.
*/
-#include "smartselect/feature-processor.h"
+#include "feature-processor.h"
+
+#include "model-executor.h"
+#include "tensor-view.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
namespace {
using testing::ElementsAreArray;
using testing::FloatEq;
+flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
+ const FeatureProcessorOptionsT& options) {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(CreateFeatureProcessorOptions(builder, &options));
+ return builder.Release();
+}
+
class TestingFeatureProcessor : public FeatureProcessor {
public:
using FeatureProcessor::CountIgnoredSpanBoundaryCodepoints;
@@ -37,6 +47,24 @@
using FeatureProcessor::SupportedCodepointsRatio;
};
+// Fake EmbeddingExecutor that returns embeddings derived from the first sparse feature value.
+class FakeEmbeddingExecutor : public EmbeddingExecutor {
+ public:
+ bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
+ int dest_size) override {
+ TC_CHECK_GE(dest_size, 4);
+ EXPECT_EQ(sparse_features.size(), 1);
+ dest[0] = sparse_features.data()[0];
+ dest[1] = sparse_features.data()[0];
+ dest[2] = -sparse_features.data()[0];
+ dest[3] = -sparse_features.data()[0];
+ return true;
+ }
+
+ private:
+ std::vector<float> storage_;
+};
+
TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) {
std::vector<Token> tokens{Token("Hělló", 0, 5),
Token("fěěbař@google.com", 6, 23),
@@ -119,9 +147,11 @@
}
TEST(FeatureProcessorTest, KeepLineWithClickFirst) {
- FeatureProcessorOptions options;
- options.set_only_use_line_with_click(true);
- TestingFeatureProcessor feature_processor(options);
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
const CodepointSpan span = {0, 5};
@@ -141,9 +171,11 @@
}
TEST(FeatureProcessorTest, KeepLineWithClickSecond) {
- FeatureProcessorOptions options;
- options.set_only_use_line_with_click(true);
- TestingFeatureProcessor feature_processor(options);
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
const CodepointSpan span = {18, 22};
@@ -163,9 +195,11 @@
}
TEST(FeatureProcessorTest, KeepLineWithClickThird) {
- FeatureProcessorOptions options;
- options.set_only_use_line_with_click(true);
- TestingFeatureProcessor feature_processor(options);
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
const CodepointSpan span = {24, 33};
@@ -185,9 +219,11 @@
}
TEST(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) {
- FeatureProcessorOptions options;
- options.set_only_use_line_with_click(true);
- TestingFeatureProcessor feature_processor(options);
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
const CodepointSpan span = {18, 22};
@@ -207,9 +243,11 @@
}
TEST(FeatureProcessorTest, KeepLineWithCrosslineClick) {
- FeatureProcessorOptions options;
- options.set_only_use_line_with_click(true);
- TestingFeatureProcessor feature_processor(options);
+ FeatureProcessorOptionsT options;
+ options.only_use_line_with_click = true;
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
const CodepointSpan span = {5, 23};
@@ -231,18 +269,21 @@
}
TEST(FeatureProcessorTest, SpanToLabel) {
- FeatureProcessorOptions options;
- options.set_context_size(1);
- options.set_max_selection_span(1);
- options.set_snap_label_span_boundaries_to_containing_tokens(false);
+ FeatureProcessorOptionsT options;
+ options.context_size = 1;
+ options.max_selection_span = 1;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
- TokenizationCodepointRange* config =
- options.add_tokenization_codepoint_config();
- config->set_start(32);
- config->set_end(33);
- config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+ options.tokenization_codepoint_config.emplace_back(
+ new TokenizationCodepointRangeT());
+ auto& config = options.tokenization_codepoint_config.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
- TestingFeatureProcessor feature_processor(options);
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
ASSERT_EQ(3, tokens.size());
int label;
@@ -256,8 +297,11 @@
EXPECT_EQ(0, token_span.second);
// Reconfigure with snapping enabled.
- options.set_snap_label_span_boundaries_to_containing_tokens(true);
- TestingFeatureProcessor feature_processor2(options);
+ options.snap_label_span_boundaries_to_containing_tokens = true;
+ flatbuffers::DetachedBuffer options2_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor2(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()));
int label2;
ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
EXPECT_EQ(label, label2);
@@ -273,9 +317,12 @@
EXPECT_EQ(kInvalidLabel, label2);
// Multiple tokens.
- options.set_context_size(2);
- options.set_max_selection_span(2);
- TestingFeatureProcessor feature_processor3(options);
+ options.context_size = 2;
+ options.max_selection_span = 2;
+ flatbuffers::DetachedBuffer options3_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor3(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()));
tokens = feature_processor3.Tokenize("zero, one, two, three, four");
ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
EXPECT_NE(kInvalidLabel, label2);
@@ -293,18 +340,21 @@
}
TEST(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) {
- FeatureProcessorOptions options;
- options.set_context_size(1);
- options.set_max_selection_span(1);
- options.set_snap_label_span_boundaries_to_containing_tokens(false);
+ FeatureProcessorOptionsT options;
+ options.context_size = 1;
+ options.max_selection_span = 1;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
- TokenizationCodepointRange* config =
- options.add_tokenization_codepoint_config();
- config->set_start(32);
- config->set_end(33);
- config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+ options.tokenization_codepoint_config.emplace_back(
+ new TokenizationCodepointRangeT());
+ auto& config = options.tokenization_codepoint_config.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
- TestingFeatureProcessor feature_processor(options);
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
ASSERT_EQ(3, tokens.size());
int label;
@@ -318,8 +368,11 @@
EXPECT_EQ(0, token_span.second);
// Reconfigure with snapping enabled.
- options.set_snap_label_span_boundaries_to_containing_tokens(true);
- TestingFeatureProcessor feature_processor2(options);
+ options.snap_label_span_boundaries_to_containing_tokens = true;
+ flatbuffers::DetachedBuffer options2_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor2(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()));
int label2;
ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
EXPECT_EQ(label, label2);
@@ -335,9 +388,12 @@
EXPECT_EQ(kInvalidLabel, label2);
// Multiple tokens.
- options.set_context_size(2);
- options.set_max_selection_span(2);
- TestingFeatureProcessor feature_processor3(options);
+ options.context_size = 2;
+ options.max_selection_span = 2;
+ flatbuffers::DetachedBuffer options3_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor3(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()));
tokens = feature_processor3.Tokenize("zero, one, two, three, four");
ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
EXPECT_NE(kInvalidLabel, label2);
@@ -420,39 +476,64 @@
}
TEST(FeatureProcessorTest, SupportedCodepointsRatio) {
- FeatureProcessorOptions options;
- options.set_context_size(2);
- options.set_max_selection_span(2);
- options.set_snap_label_span_boundaries_to_containing_tokens(false);
+ FeatureProcessorOptionsT options;
+ options.context_size = 2;
+ options.max_selection_span = 2;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
+ options.feature_version = 2;
+ options.embedding_size = 4;
+ options.bounds_sensitive_features.reset(
+ new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
+ options.bounds_sensitive_features->enabled = true;
+ options.bounds_sensitive_features->num_tokens_before = 5;
+ options.bounds_sensitive_features->num_tokens_inside_left = 3;
+ options.bounds_sensitive_features->num_tokens_inside_right = 3;
+ options.bounds_sensitive_features->num_tokens_after = 5;
+ options.bounds_sensitive_features->include_inside_bag = true;
+ options.bounds_sensitive_features->include_inside_length = true;
- TokenizationCodepointRange* config =
- options.add_tokenization_codepoint_config();
- config->set_start(32);
- config->set_end(33);
- config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+ options.tokenization_codepoint_config.emplace_back(
+ new TokenizationCodepointRangeT());
+ auto& config = options.tokenization_codepoint_config.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
- FeatureProcessorOptions::CodepointRange* range;
- range = options.add_supported_codepoint_ranges();
- range->set_start(0);
- range->set_end(128);
+ {
+ options.supported_codepoint_ranges.emplace_back(
+ new FeatureProcessorOptions_::CodepointRangeT());
+ auto& range = options.supported_codepoint_ranges.back();
+ range->start = 0;
+ range->end = 128;
+ }
- range = options.add_supported_codepoint_ranges();
- range->set_start(10000);
- range->set_end(10001);
+ {
+ options.supported_codepoint_ranges.emplace_back(
+ new FeatureProcessorOptions_::CodepointRangeT());
+ auto& range = options.supported_codepoint_ranges.back();
+ range->start = 10000;
+ range->end = 10001;
+ }
- range = options.add_supported_codepoint_ranges();
- range->set_start(20000);
- range->set_end(30000);
+ {
+ options.supported_codepoint_ranges.emplace_back(
+ new FeatureProcessorOptions_::CodepointRangeT());
+ auto& range = options.supported_codepoint_ranges.back();
+ range->start = 20000;
+ range->end = 30000;
+ }
- TestingFeatureProcessor feature_processor(options);
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
- 1, feature_processor.Tokenize("aaa bbb ccc")),
+ {0, 3}, feature_processor.Tokenize("aaa bbb ccc")),
FloatEq(1.0));
EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
- 1, feature_processor.Tokenize("aaa bbb ěěě")),
+ {0, 3}, feature_processor.Tokenize("aaa bbb ěěě")),
FloatEq(2.0 / 3));
EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
- 1, feature_processor.Tokenize("ěěě řřř ěěě")),
+ {0, 3}, feature_processor.Tokenize("ěěě řřř ěěě")),
FloatEq(0.0));
EXPECT_FALSE(feature_processor.IsCodepointInRanges(
-1, feature_processor.supported_codepoint_ranges_));
@@ -473,32 +554,39 @@
EXPECT_TRUE(feature_processor.IsCodepointInRanges(
25000, feature_processor.supported_codepoint_ranges_));
- std::vector<Token> tokens;
- int click_pos;
- std::vector<float> extra_features;
std::unique_ptr<CachedFeatures> cached_features;
- auto feature_fn = [](const std::vector<int>& sparse_features,
- const std::vector<float>& dense_features,
- float* embedding) { return true; };
+ FakeEmbeddingExecutor embedding_executor;
- options.set_min_supported_codepoint_ratio(0.0);
- TestingFeatureProcessor feature_processor2(options);
- EXPECT_TRUE(feature_processor2.ExtractFeatures("ěěě řřř eee", {4, 7}, {0, 0},
- feature_fn, 2, &tokens,
- &click_pos, &cached_features));
+ const std::vector<Token> tokens = {Token("ěěě", 0, 3), Token("řřř", 4, 7),
+ Token("eee", 8, 11)};
- options.set_min_supported_codepoint_ratio(0.2);
- TestingFeatureProcessor feature_processor3(options);
- EXPECT_TRUE(feature_processor3.ExtractFeatures("ěěě řřř eee", {4, 7}, {0, 0},
- feature_fn, 2, &tokens,
- &click_pos, &cached_features));
+ options.min_supported_codepoint_ratio = 0.0;
+ flatbuffers::DetachedBuffer options2_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor2(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()));
+ EXPECT_TRUE(feature_processor2.ExtractFeatures(
+ tokens, {0, 3}, &embedding_executor,
+ /*feature_vector_size=*/4, &cached_features));
- options.set_min_supported_codepoint_ratio(0.5);
- TestingFeatureProcessor feature_processor4(options);
+ options.min_supported_codepoint_ratio = 0.2;
+ flatbuffers::DetachedBuffer options3_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor3(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()));
+ EXPECT_TRUE(feature_processor3.ExtractFeatures(
+ tokens, {0, 3}, &embedding_executor,
+ /*feature_vector_size=*/4, &cached_features));
+
+ options.min_supported_codepoint_ratio = 0.5;
+ flatbuffers::DetachedBuffer options4_fb =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor4(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options4_fb.data()));
EXPECT_FALSE(feature_processor4.ExtractFeatures(
- "ěěě řřř eee", {4, 7}, {0, 0}, feature_fn, 2, &tokens, &click_pos,
- &cached_features));
+ tokens, {0, 3}, &embedding_executor,
+ /*feature_vector_size=*/4, &cached_features));
}
TEST(FeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) {
@@ -613,12 +701,45 @@
EXPECT_EQ(click_index, 5);
}
-TEST(FeatureProcessorTest, ICUTokenize) {
- FeatureProcessorOptions options;
- options.set_tokenization_type(
- libtextclassifier::FeatureProcessorOptions::ICU);
+TEST(FeatureProcessorTest, InternalTokenizeOnScriptChange) {
+ FeatureProcessorOptionsT options;
+ options.tokenization_codepoint_config.emplace_back(
+ new TokenizationCodepointRangeT());
+ {
+ auto& config = options.tokenization_codepoint_config.back();
+ config->start = 0;
+ config->end = 256;
+ config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+ config->script_id = 1;
+ }
+ options.tokenize_on_script_change = false;
- TestingFeatureProcessor feature_processor(options);
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
+
+ EXPECT_EQ(feature_processor.Tokenize("앨라배마123웹사이트"),
+ std::vector<Token>({Token("앨라배마123웹사이트", 0, 11)}));
+
+ options.tokenize_on_script_change = true;
+ flatbuffers::DetachedBuffer options_fb2 =
+ PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor2(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb2.data()));
+
+ EXPECT_EQ(feature_processor2.Tokenize("앨라배마123웹사이트"),
+ std::vector<Token>({Token("앨라배마", 0, 4), Token("123", 4, 7),
+ Token("웹사이트", 7, 11)}));
+}
+
+#ifdef LIBTEXTCLASSIFIER_TEST_ICU
+TEST(FeatureProcessorTest, ICUTokenize) {
+ FeatureProcessorOptionsT options;
+ options.tokenization_type = FeatureProcessorOptions_::TokenizationType_ICU;
+
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
std::vector<Token> tokens = feature_processor.Tokenize("พระบาทสมเด็จพระปรมิ");
ASSERT_EQ(tokens,
// clang-format off
@@ -629,14 +750,17 @@
Token("มิ", 17, 19)}));
// clang-format on
}
+#endif
+#ifdef LIBTEXTCLASSIFIER_TEST_ICU
TEST(FeatureProcessorTest, ICUTokenizeWithWhitespaces) {
- FeatureProcessorOptions options;
- options.set_tokenization_type(
- libtextclassifier::FeatureProcessorOptions::ICU);
- options.set_icu_preserve_whitespace_tokens(true);
+ FeatureProcessorOptionsT options;
+ options.tokenization_type = FeatureProcessorOptions_::TokenizationType_ICU;
+ options.icu_preserve_whitespace_tokens = true;
- TestingFeatureProcessor feature_processor(options);
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
std::vector<Token> tokens =
feature_processor.Tokenize("พระบาท สมเด็จ พระ ปร มิ");
ASSERT_EQ(tokens,
@@ -652,36 +776,55 @@
Token("มิ", 21, 23)}));
// clang-format on
}
+#endif
+#ifdef LIBTEXTCLASSIFIER_TEST_ICU
TEST(FeatureProcessorTest, MixedTokenize) {
- FeatureProcessorOptions options;
- options.set_tokenization_type(
- libtextclassifier::FeatureProcessorOptions::MIXED);
+ FeatureProcessorOptionsT options;
+ options.tokenization_type = FeatureProcessorOptions_::TokenizationType_MIXED;
- TokenizationCodepointRange* config =
- options.add_tokenization_codepoint_config();
- config->set_start(32);
- config->set_end(33);
- config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+ options.tokenization_codepoint_config.emplace_back(
+ new TokenizationCodepointRangeT());
+ auto& config = options.tokenization_codepoint_config.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
- FeatureProcessorOptions::CodepointRange* range;
- range = options.add_internal_tokenizer_codepoint_ranges();
- range->set_start(0);
- range->set_end(128);
+ {
+ options.internal_tokenizer_codepoint_ranges.emplace_back(
+ new FeatureProcessorOptions_::CodepointRangeT());
+ auto& range = options.internal_tokenizer_codepoint_ranges.back();
+ range->start = 0;
+ range->end = 128;
+ }
- range = options.add_internal_tokenizer_codepoint_ranges();
- range->set_start(128);
- range->set_end(256);
+ {
+ options.internal_tokenizer_codepoint_ranges.emplace_back(
+ new FeatureProcessorOptions_::CodepointRangeT());
+ auto& range = options.internal_tokenizer_codepoint_ranges.back();
+ range->start = 128;
+ range->end = 256;
+ }
- range = options.add_internal_tokenizer_codepoint_ranges();
- range->set_start(256);
- range->set_end(384);
+ {
+ options.internal_tokenizer_codepoint_ranges.emplace_back(
+ new FeatureProcessorOptions_::CodepointRangeT());
+ auto& range = options.internal_tokenizer_codepoint_ranges.back();
+ range->start = 256;
+ range->end = 384;
+ }
- range = options.add_internal_tokenizer_codepoint_ranges();
- range->set_start(384);
- range->set_end(592);
+ {
+ options.internal_tokenizer_codepoint_ranges.emplace_back(
+ new FeatureProcessorOptions_::CodepointRangeT());
+ auto& range = options.internal_tokenizer_codepoint_ranges.back();
+ range->start = 384;
+ range->end = 592;
+ }
- TestingFeatureProcessor feature_processor(options);
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
std::vector<Token> tokens = feature_processor.Tokenize(
"こんにちはJapanese-ląnguagę text 世界 http://www.google.com/");
ASSERT_EQ(tokens,
@@ -693,15 +836,18 @@
Token("http://www.google.com/", 31, 53)}));
// clang-format on
}
+#endif
TEST(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) {
- FeatureProcessorOptions options;
- options.add_ignored_span_boundary_codepoints('.');
- options.add_ignored_span_boundary_codepoints(',');
- options.add_ignored_span_boundary_codepoints('[');
- options.add_ignored_span_boundary_codepoints(']');
+ FeatureProcessorOptionsT options;
+ options.ignored_span_boundary_codepoints.push_back('.');
+ options.ignored_span_boundary_codepoints.push_back(',');
+ options.ignored_span_boundary_codepoints.push_back('[');
+ options.ignored_span_boundary_codepoints.push_back(']');
- TestingFeatureProcessor feature_processor(options);
+ flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
+ TestingFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()));
const std::string text1_utf8 = "ěščř";
const UnicodeText text1 = UTF8ToUnicodeText(text1_utf8, /*do_copy=*/false);
@@ -834,4 +980,4 @@
}
} // namespace
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
diff --git a/jni.lds b/jni.lds
index 75d5bc5..171cc0e 100644
--- a/jni.lds
+++ b/jni.lds
@@ -1,7 +1,7 @@
-{
- # Export symbols that correspond to our JNIEXPORTed functions.
+VERS_1.0 {
+ # Export JNI symbols.
global:
- Java_android_view_textclassifier_*;
+ Java_*;
# Hide everything else.
local:
diff --git a/lang_id/custom-tokenizer.cc b/lang_id/custom-tokenizer.cc
deleted file mode 100644
index 7e30cc7..0000000
--- a/lang_id/custom-tokenizer.cc
+++ /dev/null
@@ -1,109 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/custom-tokenizer.h"
-
-#include <ctype.h>
-
-#include <string>
-
-#include "util/strings/utf8.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace lang_id {
-
-namespace {
-inline bool IsTokenSeparator(int num_bytes, const char *curr) {
- if (num_bytes != 1) {
- return false;
- }
- return !isalpha(*curr);
-}
-} // namespace
-
-const char *GetSafeEndOfString(const char *data, size_t size) {
- const char *const hard_end = data + size;
- const char *curr = data;
- while (curr < hard_end) {
- int num_bytes = GetNumBytesForUTF8Char(curr);
- if (num_bytes == 0) {
- break;
- }
- const char *new_curr = curr + num_bytes;
- if (new_curr > hard_end) {
- return curr;
- }
- curr = new_curr;
- }
- return curr;
-}
-
-void TokenizeTextForLangId(const std::string &text, LightSentence *sentence) {
- const char *const start = text.data();
- const char *curr = start;
- const char *end = GetSafeEndOfString(start, text.size());
-
- // Corner case: empty safe part of the text.
- if (curr >= end) {
- return;
- }
-
- // Number of bytes for UTF8 character starting at *curr. Note: the loop below
- // is guaranteed to terminate because in each iteration, we move curr by at
- // least num_bytes, and num_bytes is guaranteed to be > 0.
- int num_bytes = GetNumBytesForNonZeroUTF8Char(curr);
- while (curr < end) {
- // Jump over consecutive token separators.
- while (IsTokenSeparator(num_bytes, curr)) {
- curr += num_bytes;
- if (curr >= end) {
- return;
- }
- num_bytes = GetNumBytesForNonZeroUTF8Char(curr);
- }
-
- // If control reaches this point, we are at beginning of a non-empty token.
- std::string *word = sentence->add_word();
-
- // Add special token-start character.
- word->push_back('^');
-
- // Add UTF8 characters to word, until we hit the end of the safe text or a
- // token separator.
- while (true) {
- word->append(curr, num_bytes);
- curr += num_bytes;
- if (curr >= end) {
- break;
- }
- num_bytes = GetNumBytesForNonZeroUTF8Char(curr);
- if (IsTokenSeparator(num_bytes, curr)) {
- curr += num_bytes;
- num_bytes = GetNumBytesForNonZeroUTF8Char(curr);
- break;
- }
- }
- word->push_back('$');
-
- // Note: we intentionally do not token.set_start()/end(), as those fields
- // are not used by the langid model.
- }
-}
-
-} // namespace lang_id
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/lang_id/custom-tokenizer.h b/lang_id/custom-tokenizer.h
deleted file mode 100644
index c9c291c..0000000
--- a/lang_id/custom-tokenizer.h
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_LANG_ID_CUSTOM_TOKENIZER_H_
-#define LIBTEXTCLASSIFIER_LANG_ID_CUSTOM_TOKENIZER_H_
-
-#include <cstddef>
-#include <string>
-
-#include "lang_id/light-sentence.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace lang_id {
-
-// Perform custom tokenization of text. Customized for the language
-// identification project. Currently (Sep 15, 2016) we tokenize on space,
-// newline, and tab, ignore all empty tokens, and (for each of the remaining
-// tokens) prepend "^" (special token begin marker) and append "$" (special
-// token end marker).
-//
-// Tokens are stored into the words of the LightSentence *sentence.
-void TokenizeTextForLangId(const std::string &text, LightSentence *sentence);
-
-// Returns a pointer "end" inside [data, data + size) such that the prefix from
-// [data, end) is the largest one that does not contain '\0' and offers the
-// following guarantee: if one starts with
-//
-// curr = text.data()
-//
-// and keeps executing
-//
-// curr += utils::GetNumBytesForNonZeroUTF8Char(curr)
-//
-// one would eventually reach curr == end (the pointer returned by this
-// function) without accessing data outside the std::string. This guards
-// against scenarios like a broken UTF-8 string which has only e.g., the first 2
-// bytes from a 3-byte UTF8 sequence.
-const char *GetSafeEndOfString(const char *data, size_t size);
-
-static inline const char *GetSafeEndOfString(const std::string &text) {
- return GetSafeEndOfString(text.data(), text.size());
-}
-
-} // namespace lang_id
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_LANG_ID_CUSTOM_TOKENIZER_H_
diff --git a/lang_id/lang-id-brain-interface.h b/lang_id/lang-id-brain-interface.h
deleted file mode 100644
index ce79497..0000000
--- a/lang_id/lang-id-brain-interface.h
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_
-#define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_
-
-#include <string>
-#include <vector>
-
-#include "common/embedding-feature-extractor.h"
-#include "common/feature-extractor.h"
-#include "common/task-context.h"
-#include "common/workspace.h"
-#include "lang_id/light-sentence-features.h"
-#include "lang_id/light-sentence.h"
-#include "util/base/macros.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace lang_id {
-
-// Specialization of EmbeddingFeatureExtractor that extracts from LightSentence.
-class LangIdEmbeddingFeatureExtractor
- : public EmbeddingFeatureExtractor<LightSentenceExtractor, LightSentence> {
- public:
- LangIdEmbeddingFeatureExtractor() {}
- const std::string ArgPrefix() const override { return "language_identifier"; }
-
- TC_DISALLOW_COPY_AND_ASSIGN(LangIdEmbeddingFeatureExtractor);
-};
-
-// Handles sentence -> numeric_features and numeric_prediction -> language
-// conversions.
-class LangIdBrainInterface {
- public:
- LangIdBrainInterface() {}
-
- // Initializes resources and parameters.
- bool Init(TaskContext *context) {
- if (!feature_extractor_.Init(context)) {
- return false;
- }
- feature_extractor_.RequestWorkspaces(&workspace_registry_);
- return true;
- }
-
- // Extract features from sentence. On return, FeatureVector features[i]
- // contains the features for the embedding space #i.
- void GetFeatures(LightSentence *sentence,
- std::vector<FeatureVector> *features) const {
- WorkspaceSet workspace;
- workspace.Reset(workspace_registry_);
- feature_extractor_.Preprocess(&workspace, sentence);
- return feature_extractor_.ExtractFeatures(workspace, *sentence, features);
- }
-
- int NumEmbeddings() const {
- return feature_extractor_.NumEmbeddings();
- }
-
- private:
- // Typed feature extractor for embeddings.
- LangIdEmbeddingFeatureExtractor feature_extractor_;
-
- // The registry of shared workspaces in the feature extractor.
- WorkspaceRegistry workspace_registry_;
-
- TC_DISALLOW_COPY_AND_ASSIGN(LangIdBrainInterface);
-};
-
-} // namespace lang_id
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_
diff --git a/lang_id/lang-id.cc b/lang_id/lang-id.cc
deleted file mode 100644
index 8383d33..0000000
--- a/lang_id/lang-id.cc
+++ /dev/null
@@ -1,402 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/lang-id.h"
-
-#include <stdio.h>
-
-#include <algorithm>
-#include <limits>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "common/algorithm.h"
-#include "common/embedding-network-params-from-proto.h"
-#include "common/embedding-network.pb.h"
-#include "common/embedding-network.h"
-#include "common/feature-extractor.h"
-#include "common/file-utils.h"
-#include "common/list-of-strings.pb.h"
-#include "common/memory_image/in-memory-model-data.h"
-#include "common/mmap.h"
-#include "common/softmax.h"
-#include "common/task-context.h"
-#include "lang_id/custom-tokenizer.h"
-#include "lang_id/lang-id-brain-interface.h"
-#include "lang_id/language-identifier-features.h"
-#include "lang_id/light-sentence-features.h"
-#include "lang_id/light-sentence.h"
-#include "lang_id/relevant-script-feature.h"
-#include "util/base/logging.h"
-#include "util/base/macros.h"
-
-using ::libtextclassifier::nlp_core::file_utils::ParseProtoFromMemory;
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace lang_id {
-
-namespace {
-// Default value for the probability threshold; see comments for
-// LangId::SetProbabilityThreshold().
-static const float kDefaultProbabilityThreshold = 0.50;
-
-// Default value for min text size below which our model can't provide a
-// meaningful prediction.
-static const int kDefaultMinTextSizeInBytes = 20;
-
-// Initial value for the default language for LangId::FindLanguage(). The
-// default language can be changed (for an individual LangId object) using
-// LangId::SetDefaultLanguage().
-static const char kInitialDefaultLanguage[] = "";
-
-// Returns total number of bytes of the words from sentence, without the ^
-// (start-of-word) and $ (end-of-word) markers. Note: "real text" means that
-// this ignores whitespace and punctuation characters from the original text.
-int GetRealTextSize(const LightSentence &sentence) {
- int total = 0;
- for (int i = 0; i < sentence.num_words(); ++i) {
- TC_DCHECK(!sentence.word(i).empty());
- TC_DCHECK_EQ('^', sentence.word(i).front());
- TC_DCHECK_EQ('$', sentence.word(i).back());
- total += sentence.word(i).size() - 2;
- }
- return total;
-}
-
-} // namespace
-
-// Class that performs all work behind LangId.
-class LangIdImpl {
- public:
- explicit LangIdImpl(const std::string &filename) {
- // Using mmap as a fast way to read the model bytes.
- ScopedMmap scoped_mmap(filename);
- MmapHandle mmap_handle = scoped_mmap.handle();
- if (!mmap_handle.ok()) {
- TC_LOG(ERROR) << "Unable to read model bytes.";
- return;
- }
-
- Initialize(mmap_handle.to_stringpiece());
- }
-
- explicit LangIdImpl(int fd) {
- // Using mmap as a fast way to read the model bytes.
- ScopedMmap scoped_mmap(fd);
- MmapHandle mmap_handle = scoped_mmap.handle();
- if (!mmap_handle.ok()) {
- TC_LOG(ERROR) << "Unable to read model bytes.";
- return;
- }
-
- Initialize(mmap_handle.to_stringpiece());
- }
-
- LangIdImpl(const char *ptr, size_t length) {
- Initialize(StringPiece(ptr, length));
- }
-
- void Initialize(StringPiece model_bytes) {
- // Will set valid_ to true only on successful initialization.
- valid_ = false;
-
- // Make sure all relevant features are registered:
- ContinuousBagOfNgramsFunction::RegisterClass();
- RelevantScriptFeature::RegisterClass();
-
- // NOTE(salcianu): code below relies on the fact that the current features
- // do not rely on data from a TaskInput. Otherwise, one would have to use
- // the more complex model registration mechanism, which requires more code.
- InMemoryModelData model_data(model_bytes);
- TaskContext context;
- if (!model_data.GetTaskSpec(context.mutable_spec())) {
- TC_LOG(ERROR) << "Unable to get model TaskSpec";
- return;
- }
-
- if (!ParseNetworkParams(model_data, &context)) {
- return;
- }
- if (!ParseListOfKnownLanguages(model_data, &context)) {
- return;
- }
-
- network_.reset(new EmbeddingNetwork(network_params_.get()));
- if (!network_->is_valid()) {
- return;
- }
-
- probability_threshold_ =
- context.Get("reliability_thresh", kDefaultProbabilityThreshold);
- min_text_size_in_bytes_ =
- context.Get("min_text_size_in_bytes", kDefaultMinTextSizeInBytes);
- version_ = context.Get("version", 0);
-
- if (!lang_id_brain_interface_.Init(&context)) {
- return;
- }
- valid_ = true;
- }
-
- void SetProbabilityThreshold(float threshold) {
- probability_threshold_ = threshold;
- }
-
- void SetDefaultLanguage(const std::string &lang) { default_language_ = lang; }
-
- std::string FindLanguage(const std::string &text) const {
- std::vector<float> scores = ScoreLanguages(text);
- if (scores.empty()) {
- return default_language_;
- }
-
- // Softmax label with max score.
- int label = GetArgMax(scores);
- float probability = scores[label];
- if (probability < probability_threshold_) {
- return default_language_;
- }
- return GetLanguageForSoftmaxLabel(label);
- }
-
- std::vector<std::pair<std::string, float>> FindLanguages(
- const std::string &text) const {
- std::vector<float> scores = ScoreLanguages(text);
-
- std::vector<std::pair<std::string, float>> result;
- for (int i = 0; i < scores.size(); i++) {
- result.push_back({GetLanguageForSoftmaxLabel(i), scores[i]});
- }
-
- // To avoid crashing clients that always expect at least one predicted
- // language, we promised (see doc for this method) that the result always
- // contains at least one element.
- if (result.empty()) {
- // We use a tiny probability, such that any client that uses a meaningful
- // probability threshold ignores this prediction. We don't use 0.0f, to
- // avoid crashing clients that normalize the probabilities we return here.
- result.push_back({default_language_, 0.001f});
- }
- return result;
- }
-
- std::vector<float> ScoreLanguages(const std::string &text) const {
- if (!is_valid()) {
- return {};
- }
-
- // Create a Sentence storing the input text.
- LightSentence sentence;
- TokenizeTextForLangId(text, &sentence);
-
- if (GetRealTextSize(sentence) < min_text_size_in_bytes_) {
- return {};
- }
-
- // TODO(salcianu): reuse vector<FeatureVector>.
- std::vector<FeatureVector> features(
- lang_id_brain_interface_.NumEmbeddings());
- lang_id_brain_interface_.GetFeatures(&sentence, &features);
-
- // Predict language.
- EmbeddingNetwork::Vector scores;
- network_->ComputeFinalScores(features, &scores);
-
- return ComputeSoftmax(scores);
- }
-
- bool is_valid() const { return valid_; }
-
- int version() const { return version_; }
-
- private:
- // Returns name of the (in-memory) file for the indicated TaskInput from
- // context.
- static std::string GetInMemoryFileNameForTaskInput(
- const std::string &input_name, TaskContext *context) {
- TaskInput *task_input = context->GetInput(input_name);
- if (task_input->part_size() != 1) {
- TC_LOG(ERROR) << "TaskInput " << input_name << " has "
- << task_input->part_size() << " parts";
- return "";
- }
- return task_input->part(0).file_pattern();
- }
-
- bool ParseNetworkParams(const InMemoryModelData &model_data,
- TaskContext *context) {
- const std::string input_name = "language-identifier-network";
- const std::string input_file_name =
- GetInMemoryFileNameForTaskInput(input_name, context);
- if (input_file_name.empty()) {
- TC_LOG(ERROR) << "No input file name for TaskInput " << input_name;
- return false;
- }
- StringPiece bytes = model_data.GetBytesForInputFile(input_file_name);
- if (bytes.data() == nullptr) {
- TC_LOG(ERROR) << "Unable to get bytes for TaskInput " << input_name;
- return false;
- }
- std::unique_ptr<EmbeddingNetworkProto> proto(new EmbeddingNetworkProto());
- if (!ParseProtoFromMemory(bytes, proto.get())) {
- TC_LOG(ERROR) << "Unable to parse EmbeddingNetworkProto";
- return false;
- }
- network_params_.reset(
- new EmbeddingNetworkParamsFromProto(std::move(proto)));
- if (!network_params_->is_valid()) {
- TC_LOG(ERROR) << "EmbeddingNetworkParamsFromProto not valid";
- return false;
- }
- return true;
- }
-
- // Parses dictionary with known languages (i.e., field languages_) from a
- // TaskInput of context. Note: that TaskInput should be a ListOfStrings proto
- // with a single element, the serialized form of a ListOfStrings.
- //
- bool ParseListOfKnownLanguages(const InMemoryModelData &model_data,
- TaskContext *context) {
- const std::string input_name = "language-name-id-map";
- const std::string input_file_name =
- GetInMemoryFileNameForTaskInput(input_name, context);
- if (input_file_name.empty()) {
- TC_LOG(ERROR) << "No input file name for TaskInput " << input_name;
- return false;
- }
- StringPiece bytes = model_data.GetBytesForInputFile(input_file_name);
- if (bytes.data() == nullptr) {
- TC_LOG(ERROR) << "Unable to get bytes for TaskInput " << input_name;
- return false;
- }
- ListOfStrings records;
- if (!ParseProtoFromMemory(bytes, &records)) {
- TC_LOG(ERROR) << "Unable to parse ListOfStrings from TaskInput "
- << input_name;
- return false;
- }
- if (records.element_size() != 1) {
- TC_LOG(ERROR) << "Wrong number of records in TaskInput " << input_name
- << " : " << records.element_size();
- return false;
- }
- if (!ParseProtoFromMemory(std::string(records.element(0)), &languages_)) {
- TC_LOG(ERROR) << "Unable to parse dictionary with known languages";
- return false;
- }
- return true;
- }
-
- // Returns language code for a softmax label. See comments for languages_
- // field. If label is out of range, returns default_language_.
- std::string GetLanguageForSoftmaxLabel(int label) const {
- if ((label >= 0) && (label < languages_.element_size())) {
- return languages_.element(label);
- } else {
- TC_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
- << languages_.element_size() << ")";
- return default_language_;
- }
- }
-
- LangIdBrainInterface lang_id_brain_interface_;
-
- // Parameters for the neural network network_ (see below).
- std::unique_ptr<EmbeddingNetworkParamsFromProto> network_params_;
-
- // Neural network to use for scoring.
- std::unique_ptr<EmbeddingNetwork> network_;
-
- // True if this object is ready to perform language predictions.
- bool valid_;
-
- // Only predictions with a probability (confidence) above this threshold are
- // reported. Otherwise, we report default_language_.
- float probability_threshold_ = kDefaultProbabilityThreshold;
-
- // Min size of the input text for our predictions to be meaningful. Below
- // this threshold, the underlying model may report a wrong language and a high
- // confidence score.
- int min_text_size_in_bytes_ = kDefaultMinTextSizeInBytes;
-
- // Version of the model.
- int version_ = -1;
-
- // Known languages: softmax label i (an integer) means languages_.element(i)
- // (something like "en", "fr", "ru", etc).
- ListOfStrings languages_;
-
- // Language code to return in case of errors.
- std::string default_language_ = kInitialDefaultLanguage;
-
- TC_DISALLOW_COPY_AND_ASSIGN(LangIdImpl);
-};
-
-LangId::LangId(const std::string &filename) : pimpl_(new LangIdImpl(filename)) {
- if (!pimpl_->is_valid()) {
- TC_LOG(ERROR) << "Unable to construct a valid LangId based "
- << "on the data from " << filename
- << "; nothing should crash, but "
- << "accuracy will be bad.";
- }
-}
-
-LangId::LangId(int fd) : pimpl_(new LangIdImpl(fd)) {
- if (!pimpl_->is_valid()) {
- TC_LOG(ERROR) << "Unable to construct a valid LangId based "
- << "on the data from descriptor " << fd
- << "; nothing should crash, "
- << "but accuracy will be bad.";
- }
-}
-
-LangId::LangId(const char *ptr, size_t length)
- : pimpl_(new LangIdImpl(ptr, length)) {
- if (!pimpl_->is_valid()) {
- TC_LOG(ERROR) << "Unable to construct a valid LangId based "
- << "on the memory region; nothing should crash, "
- << "but accuracy will be bad.";
- }
-}
-
-LangId::~LangId() = default;
-
-void LangId::SetProbabilityThreshold(float threshold) {
- pimpl_->SetProbabilityThreshold(threshold);
-}
-
-void LangId::SetDefaultLanguage(const std::string &lang) {
- pimpl_->SetDefaultLanguage(lang);
-}
-
-std::string LangId::FindLanguage(const std::string &text) const {
- return pimpl_->FindLanguage(text);
-}
-
-std::vector<std::pair<std::string, float>> LangId::FindLanguages(
- const std::string &text) const {
- return pimpl_->FindLanguages(text);
-}
-
-bool LangId::is_valid() const { return pimpl_->is_valid(); }
-
-int LangId::version() const { return pimpl_->version(); }
-
-} // namespace lang_id
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/lang_id/lang-id.h b/lang_id/lang-id.h
deleted file mode 100644
index 7653dde..0000000
--- a/lang_id/lang-id.h
+++ /dev/null
@@ -1,112 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_H_
-#define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_H_
-
-// Clients who want to perform language identification should use this header.
-//
-// Note for lang id implementors: keep this header as linght as possible. E.g.,
-// any macro defined here (or in a transitively #included file) is a potential
-// name conflict with our clients.
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "util/base/macros.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace lang_id {
-
-// Forward-declaration of the class that performs all underlying work.
-class LangIdImpl;
-
-// Class for detecting the language of a document.
-//
-// NOTE: this class is thread-unsafe.
-class LangId {
- public:
- // Constructs a LangId object, loading an EmbeddingNetworkProto model from the
- // indicated file.
- //
- // Note: we don't crash if we detect a problem at construction time (e.g.,
- // file doesn't exist, or its content is corrupted). Instead, we mark the
- // newly-constructed object as invalid; clients can invoke FindLanguage() on
- // an invalid object: nothing crashes, but accuracy will be bad.
- explicit LangId(const std::string &filename);
-
- // Same as above but uses a file descriptor.
- explicit LangId(int fd);
-
- // Same as above but uses already mapped memory region
- explicit LangId(const char *ptr, size_t length);
-
- virtual ~LangId();
-
- // Sets probability threshold for predictions. If our likeliest prediction is
- // below this threshold, we report the default language (see
- // SetDefaultLanguage()). Othewise, we report the likelist language.
- //
- // By default (if this method is not called) we use the probability threshold
- // stored in the model, as the task parameter "reliability_thresh". If that
- // task parameter is not specified, we use 0.5. A client can use this method
- // to get a different precision / recall trade-off. The higher the threshold,
- // the higher the precision and lower the recall rate.
- void SetProbabilityThreshold(float threshold);
-
- // Sets default language to report if errors prevent running the real
- // inference code or if prediction confidence is too small.
- void SetDefaultLanguage(const std::string &lang);
-
- // Returns language code for the most likely language that text is written in.
- // Note: if this LangId object is not valid (see
- // is_valid()), this method returns the default language specified via
- // SetDefaultLanguage() or (if that method was never invoked), the empty
- // std::string.
- std::string FindLanguage(const std::string &text) const;
-
- // Returns a vector of language codes along with the probability for each
- // language. The result contains at least one element. The sum of
- // probabilities may be less than 1.0.
- std::vector<std::pair<std::string, float>> FindLanguages(
- const std::string &text) const;
-
- // Returns true if this object has been correctly initialized and is ready to
- // perform predictions. For more info, see doc for LangId
- // constructor above.
- bool is_valid() const;
-
- // Returns version number for the model.
- int version() const;
-
- private:
- // Returns a vector of probabilities of languages of the text.
- std::vector<float> ScoreLanguages(const std::string &text) const;
-
- // Pimpl ("pointer to implementation") pattern, to hide all internals from our
- // clients.
- std::unique_ptr<LangIdImpl> pimpl_;
-
- TC_DISALLOW_COPY_AND_ASSIGN(LangId);
-};
-
-} // namespace lang_id
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_H_
diff --git a/lang_id/lang-id_test.cc b/lang_id/lang-id_test.cc
deleted file mode 100644
index 2f8aedd..0000000
--- a/lang_id/lang-id_test.cc
+++ /dev/null
@@ -1,152 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/lang-id.h"
-
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "util/base/logging.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace lang_id {
-
-namespace {
-
-std::string GetModelPath() {
- return TEST_DATA_DIR "langid.model";
-}
-
-// Creates a LangId with default model. Passes ownership to
-// the caller.
-LangId *CreateLanguageDetector() { return new LangId(GetModelPath()); }
-
-} // namespace
-
-TEST(LangIdTest, Normal) {
- std::unique_ptr<LangId> lang_id(CreateLanguageDetector());
-
- EXPECT_EQ("en", lang_id->FindLanguage("This text is written in English."));
- EXPECT_EQ("en",
- lang_id->FindLanguage("This text is written in English. "));
- EXPECT_EQ("en",
- lang_id->FindLanguage(" This text is written in English. "));
- EXPECT_EQ("fr", lang_id->FindLanguage("Vive la France! Vive la France!"));
- EXPECT_EQ("ro", lang_id->FindLanguage("Sunt foarte foarte foarte fericit!"));
-}
-
-// Test that for very small queries, we return the default language and a low
-// confidence score.
-TEST(LangIdTest, SuperSmallQueries) {
- std::unique_ptr<LangId> lang_id(CreateLanguageDetector());
-
- // Use a default language different from any real language: to be sure the
- // result is the default language, not a language that happens to be the
- // default language.
- const std::string kDefaultLanguage = "dflt-lng";
- lang_id->SetDefaultLanguage(kDefaultLanguage);
-
- // Test the simple FindLanguage() method: that method returns a single
- // language.
- EXPECT_EQ(kDefaultLanguage, lang_id->FindLanguage("y"));
- EXPECT_EQ(kDefaultLanguage, lang_id->FindLanguage("j"));
- EXPECT_EQ(kDefaultLanguage, lang_id->FindLanguage("l"));
- EXPECT_EQ(kDefaultLanguage, lang_id->FindLanguage("w"));
- EXPECT_EQ(kDefaultLanguage, lang_id->FindLanguage("z"));
- EXPECT_EQ(kDefaultLanguage, lang_id->FindLanguage("zulu"));
-
- // Test the more complex FindLanguages() method: that method returns a vector
- // of (language, confidence_score) pairs.
- std::vector<std::pair<std::string, float>> languages;
- languages = lang_id->FindLanguages("y");
- EXPECT_EQ(1, languages.size());
- EXPECT_EQ(kDefaultLanguage, languages[0].first);
- EXPECT_GT(0.01f, languages[0].second);
-
- languages = lang_id->FindLanguages("Todoist");
- EXPECT_EQ(1, languages.size());
- EXPECT_EQ(kDefaultLanguage, languages[0].first);
- EXPECT_GT(0.01f, languages[0].second);
-
- // A few tests with a default language that is a real language code.
- const std::string kJapanese = "ja";
- lang_id->SetDefaultLanguage(kJapanese);
- EXPECT_EQ(kJapanese, lang_id->FindLanguage("y"));
- EXPECT_EQ(kJapanese, lang_id->FindLanguage("j"));
- EXPECT_EQ(kJapanese, lang_id->FindLanguage("l"));
- languages = lang_id->FindLanguages("y");
- EXPECT_EQ(1, languages.size());
- EXPECT_EQ(kJapanese, languages[0].first);
- EXPECT_GT(0.01f, languages[0].second);
-
- // Make sure the min text size limit is applied to the number of real
- // characters (e.g., without spaces and punctuation chars, which don't
- // influence language identification).
- const std::string kWhitespaces = " \t \n \t\t\t\n \t";
- const std::string kPunctuation = "... ?!!--- -%%^...-";
- std::string still_small_string = kWhitespaces + "y" + kWhitespaces +
- kPunctuation + kWhitespaces + kPunctuation +
- kPunctuation;
- EXPECT_LE(100, still_small_string.size());
- lang_id->SetDefaultLanguage(kDefaultLanguage);
- EXPECT_EQ(kDefaultLanguage, lang_id->FindLanguage(still_small_string));
- languages = lang_id->FindLanguages(still_small_string);
- EXPECT_EQ(1, languages.size());
- EXPECT_EQ(kDefaultLanguage, languages[0].first);
- EXPECT_GT(0.01f, languages[0].second);
-}
-
-namespace {
-void CheckPredictionForGibberishStrings(const std::string &default_language) {
- static const char *const kGibberish[] = {
- "",
- " ",
- " ",
- " ___ ",
- "123 456 789",
- "><> (-_-) <><",
- nullptr,
- };
-
- std::unique_ptr<LangId> lang_id(CreateLanguageDetector());
- TC_LOG(INFO) << "Default language: " << default_language;
- lang_id->SetDefaultLanguage(default_language);
- for (int i = 0; true; ++i) {
- const char *gibberish = kGibberish[i];
- if (gibberish == nullptr) {
- break;
- }
- const std::string predicted_language = lang_id->FindLanguage(gibberish);
- TC_LOG(INFO) << "Predicted " << predicted_language << " for \"" << gibberish
- << "\"";
- EXPECT_EQ(default_language, predicted_language);
- }
-}
-} // namespace
-
-TEST(LangIdTest, CornerCases) {
- CheckPredictionForGibberishStrings("en");
- CheckPredictionForGibberishStrings("ro");
- CheckPredictionForGibberishStrings("fr");
-}
-
-} // namespace lang_id
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/lang_id/language-identifier-features.cc b/lang_id/language-identifier-features.cc
deleted file mode 100644
index 2e3912e..0000000
--- a/lang_id/language-identifier-features.cc
+++ /dev/null
@@ -1,141 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/language-identifier-features.h"
-
-#include <utility>
-#include <vector>
-
-#include "common/feature-extractor.h"
-#include "common/feature-types.h"
-#include "common/task-context.h"
-#include "util/hash/hash.h"
-#include "util/strings/utf8.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace lang_id {
-
-bool ContinuousBagOfNgramsFunction::Setup(TaskContext *context) {
- // Parameters in the feature function descriptor.
- ngram_id_dimension_ = GetIntParameter("id_dim", 10000);
- ngram_size_ = GetIntParameter("size", 3);
-
- counts_.assign(ngram_id_dimension_, 0);
- return true;
-}
-
-bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) {
- set_feature_type(new NumericFeatureType(name(), ngram_id_dimension_));
- return true;
-}
-
-int ContinuousBagOfNgramsFunction::ComputeNgramCounts(
- const LightSentence &sentence) const {
- // Invariant 1: counts_.size() == ngram_id_dimension_. Holds at the end of
- // the constructor. After that, no method changes counts_.size().
- TC_DCHECK_EQ(counts_.size(), ngram_id_dimension_);
-
- // Invariant 2: the vector non_zero_count_indices_ is empty. The vector
- // non_zero_count_indices_ is empty at construction time and gets emptied at
- // the end of each call to Evaluate(). Hence, this invariant holds at the
- // beginning of each run of Evaluate(), where the only call to this code takes
- // place.
- TC_DCHECK(non_zero_count_indices_.empty());
-
- int total_count = 0;
-
- for (int i = 0; i < sentence.num_words(); ++i) {
- const std::string &word = sentence.word(i);
- const char *const word_end = word.data() + word.size();
-
- // Set ngram_start at the start of the current token (word).
- const char *ngram_start = word.data();
-
- // Set ngram_end ngram_size UTF8 characters after ngram_start. Note: each
- // UTF8 character contains between 1 and 4 bytes.
- const char *ngram_end = ngram_start;
- int num_utf8_chars = 0;
- do {
- ngram_end += GetNumBytesForNonZeroUTF8Char(ngram_end);
- num_utf8_chars++;
- } while ((num_utf8_chars < ngram_size_) && (ngram_end < word_end));
-
- if (num_utf8_chars < ngram_size_) {
- // Current token is so small, it does not contain a single ngram of
- // ngram_size UTF8 characters. Not much we can do in this case ...
- continue;
- }
-
- // At this point, [ngram_start, ngram_end) is the first ngram of ngram_size
- // UTF8 characters from current token.
- while (true) {
- // Compute ngram_id: hash(ngram) % ngram_id_dimension
- int ngram_id =
- (Hash32WithDefaultSeed(ngram_start, ngram_end - ngram_start) %
- ngram_id_dimension_);
-
- // Use a reference to the actual count, such that we can both test whether
- // the count was 0 and increment it without perfoming two lookups.
- //
- // Due to the way we compute ngram_id, 0 <= ngram_id < ngram_id_dimension.
- // Hence, by Invariant 1 (above), the access counts_[ngram_id] is safe.
- int &ref_to_count_for_ngram = counts_[ngram_id];
- if (ref_to_count_for_ngram == 0) {
- non_zero_count_indices_.push_back(ngram_id);
- }
- ref_to_count_for_ngram++;
- total_count++;
- if (ngram_end >= word_end) {
- break;
- }
-
- // Advance both ngram_start and ngram_end by one UTF8 character. This
- // way, the number of UTF8 characters between them remains constant
- // (ngram_size).
- ngram_start += GetNumBytesForNonZeroUTF8Char(ngram_start);
- ngram_end += GetNumBytesForNonZeroUTF8Char(ngram_end);
- }
- } // end of loop over tokens.
-
- return total_count;
-}
-
-void ContinuousBagOfNgramsFunction::Evaluate(const WorkspaceSet &workspaces,
- const LightSentence &sentence,
- FeatureVector *result) const {
- // Find the char ngram counts.
- int total_count = ComputeNgramCounts(sentence);
-
- // Populate the feature vector.
- const float norm = static_cast<float>(total_count);
-
- for (int ngram_id : non_zero_count_indices_) {
- const float weight = counts_[ngram_id] / norm;
- FloatFeatureValue value(ngram_id, weight);
- result->add(feature_type(), value.discrete_value);
-
- // Clear up counts_, for the next invocation of Evaluate().
- counts_[ngram_id] = 0;
- }
-
- // Clear up non_zero_count_indices_, for the next invocation of Evaluate().
- non_zero_count_indices_.clear();
-}
-
-} // namespace lang_id
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/lang_id/language-identifier-features.h b/lang_id/language-identifier-features.h
deleted file mode 100644
index a4e3b3d..0000000
--- a/lang_id/language-identifier-features.h
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_LANG_ID_LANGUAGE_IDENTIFIER_FEATURES_H_
-#define LIBTEXTCLASSIFIER_LANG_ID_LANGUAGE_IDENTIFIER_FEATURES_H_
-
-#include <string>
-
-#include "common/feature-extractor.h"
-#include "common/task-context.h"
-#include "common/workspace.h"
-#include "lang_id/light-sentence-features.h"
-#include "lang_id/light-sentence.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace lang_id {
-
-// Class for computing continuous char ngram features.
-//
-// Feature function descriptor parameters:
-// id_dim(int, 10000):
-// The integer id of each char ngram is computed as follows:
-// Hash32WithDefaultSeed(char ngram) % id_dim.
-// size(int, 3):
-// Only ngrams of this size will be extracted.
-//
-// NOTE: this class is not thread-safe. TODO(salcianu): make it thread-safe.
-class ContinuousBagOfNgramsFunction : public LightSentenceFeature {
- public:
- bool Setup(TaskContext *context) override;
- bool Init(TaskContext *context) override;
-
- // Appends the features computed from the sentence to the feature vector.
- void Evaluate(const WorkspaceSet &workspaces, const LightSentence &sentence,
- FeatureVector *result) const override;
-
- TC_DEFINE_REGISTRATION_METHOD("continuous-bag-of-ngrams",
- ContinuousBagOfNgramsFunction);
-
- private:
- // Auxiliary for Evaluate(). Fills counts_ and non_zero_count_indices_ (see
- // below), and returns the total ngram count.
- int ComputeNgramCounts(const LightSentence &sentence) const;
-
- // counts_[i] is the count of all ngrams with id i. Work data for Evaluate().
- // NOTE: we declare this vector as a field, such that its underlying capacity
- // stays allocated in between calls to Evaluate().
- mutable std::vector<int> counts_;
-
- // Indices of non-zero elements of counts_. See comments for counts_.
- mutable std::vector<int> non_zero_count_indices_;
-
- // The integer id of each char ngram is computed as follows:
- // Hash32WithDefaultSeed(char_ngram) % ngram_id_dimension_.
- int ngram_id_dimension_;
-
- // Only ngrams of size ngram_size_ will be extracted.
- int ngram_size_;
-};
-
-} // namespace lang_id
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_LANG_ID_LANGUAGE_IDENTIFIER_FEATURES_H_
diff --git a/lang_id/light-sentence-features.cc b/lang_id/light-sentence-features.cc
deleted file mode 100644
index aec6b81..0000000
--- a/lang_id/light-sentence-features.cc
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/light-sentence-features.h"
-
-#include "common/registry.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-
-// Registry for the features on whole light sentences.
-TC_DEFINE_CLASS_REGISTRY_NAME("light sentence feature function",
- lang_id::LightSentenceFeature);
-
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/lang_id/light-sentence-features.h b/lang_id/light-sentence-features.h
deleted file mode 100644
index a140f65..0000000
--- a/lang_id/light-sentence-features.h
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_LANG_ID_LIGHT_SENTENCE_FEATURES_H_
-#define LIBTEXTCLASSIFIER_LANG_ID_LIGHT_SENTENCE_FEATURES_H_
-
-#include "common/feature-extractor.h"
-#include "lang_id/light-sentence.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace lang_id {
-
-// Feature function that extracts features from LightSentences.
-typedef FeatureFunction<LightSentence> LightSentenceFeature;
-
-// Feature extractor for LightSentences.
-typedef FeatureExtractor<LightSentence> LightSentenceExtractor;
-
-} // namespace lang_id
-
-// Should be used in namespace libtextclassifier::nlp_core.
-TC_DECLARE_CLASS_REGISTRY_NAME(lang_id::LightSentenceFeature);
-
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_LANG_ID_LIGHT_SENTENCE_FEATURES_H_
diff --git a/lang_id/light-sentence.h b/lang_id/light-sentence.h
deleted file mode 100644
index e8451be..0000000
--- a/lang_id/light-sentence.h
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_LANG_ID_LIGHT_SENTENCE_H_
-#define LIBTEXTCLASSIFIER_LANG_ID_LIGHT_SENTENCE_H_
-
-#include <string>
-#include <vector>
-
-#include "util/base/logging.h"
-#include "util/base/macros.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace lang_id {
-
-// Simplified replacement for the Sentence proto, for internal use in the
-// language identification code.
-//
-// In this simplified form, a sentence is a vector of words, each word being a
-// string.
-class LightSentence {
- public:
- LightSentence() {}
-
- // Adds a new word after all existing ones, and returns a pointer to it. The
- // new word is initialized to the empty string.
- std::string *add_word() {
- words_.emplace_back();
- return &(words_.back());
- }
-
- // Returns number of words from this LightSentence.
- int num_words() const { return words_.size(); }
-
- // Returns the ith word from this LightSentence. Note: undefined behavior if
- // i is out of bounds.
- const std::string &word(int i) const {
- TC_DCHECK((i >= 0) && (i < num_words()));
- return words_[i];
- }
-
- private:
- std::vector<std::string> words_;
-
- TC_DISALLOW_COPY_AND_ASSIGN(LightSentence);
-};
-
-} // namespace lang_id
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_LANG_ID_LIGHT_SENTENCE_H_
diff --git a/lang_id/relevant-script-feature.cc b/lang_id/relevant-script-feature.cc
deleted file mode 100644
index c865ce5..0000000
--- a/lang_id/relevant-script-feature.cc
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/relevant-script-feature.h"
-
-#include <string>
-
-#include "common/feature-extractor.h"
-#include "common/feature-types.h"
-#include "common/task-context.h"
-#include "common/workspace.h"
-#include "lang_id/script-detector.h"
-#include "util/base/logging.h"
-#include "util/strings/utf8.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace lang_id {
-
-bool RelevantScriptFeature::Setup(TaskContext *context) { return true; }
-
-bool RelevantScriptFeature::Init(TaskContext *context) {
- set_feature_type(new NumericFeatureType(name(), kNumRelevantScripts));
- return true;
-}
-
-void RelevantScriptFeature::Evaluate(const WorkspaceSet &workspaces,
- const LightSentence &sentence,
- FeatureVector *result) const {
- // We expect kNumRelevantScripts to be small, so we stack-allocate the array
- // of counts. Still, if that changes, we want to find out.
- static_assert(
- kNumRelevantScripts < 25,
- "switch counts to vector<int>: too big for stack-allocated int[]");
-
- // counts[s] is the number of characters with script s.
- // Note: {} "value-initializes" the array to zero.
- int counts[kNumRelevantScripts]{};
- int total_count = 0;
- for (int i = 0; i < sentence.num_words(); ++i) {
- const std::string &word = sentence.word(i);
- const char *const word_end = word.data() + word.size();
- const char *curr = word.data();
-
- // Skip over token start '^'.
- TC_DCHECK_EQ(*curr, '^');
- curr += GetNumBytesForNonZeroUTF8Char(curr);
- while (true) {
- const int num_bytes = GetNumBytesForNonZeroUTF8Char(curr);
- Script script = GetScript(curr, num_bytes);
-
- // We do this update and the if (...) break below *before* incrementing
- // counts[script] in order to skip the token end '$'.
- curr += num_bytes;
- if (curr >= word_end) {
- TC_DCHECK_EQ(*(curr - num_bytes), '$');
- break;
- }
- TC_DCHECK_GE(script, 0);
- TC_DCHECK_LT(script, kNumRelevantScripts);
- counts[script]++;
- total_count++;
- }
- }
-
- for (int script_id = 0; script_id < kNumRelevantScripts; ++script_id) {
- int count = counts[script_id];
- if (count > 0) {
- const float weight = static_cast<float>(count) / total_count;
- FloatFeatureValue value(script_id, weight);
- result->add(feature_type(), value.discrete_value);
- }
- }
-}
-
-} // namespace lang_id
-} // namespace nlp_core
-} // namespace libtextclassifier
diff --git a/lang_id/relevant-script-feature.h b/lang_id/relevant-script-feature.h
deleted file mode 100644
index 2aa2420..0000000
--- a/lang_id/relevant-script-feature.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_LANG_ID_RELEVANT_SCRIPT_FEATURE_H_
-#define LIBTEXTCLASSIFIER_LANG_ID_RELEVANT_SCRIPT_FEATURE_H_
-
-#include "common/feature-extractor.h"
-#include "common/task-context.h"
-#include "common/workspace.h"
-#include "lang_id/light-sentence-features.h"
-#include "lang_id/light-sentence.h"
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace lang_id {
-
-// Given a sentence, generates one FloatFeatureValue for each "relevant" Unicode
-// script (see below): each such feature indicates the script and the ratio of
-// UTF8 characters in that script, in the given sentence.
-//
-// What is a relevant script? Recognizing all 100+ Unicode scripts would
-// require too much code size and runtime. Instead, we focus only on a few
-// scripts that communicate a lot of language information: e.g., the use of
-// Hiragana characters almost always indicates Japanese, so Hiragana is a
-// "relevant" script for us. The Latin script is used by dozens of language, so
-// Latin is not relevant in this context.
-class RelevantScriptFeature : public LightSentenceFeature {
- public:
- // Idiomatic SAFT Setup() and Init().
- bool Setup(TaskContext *context) override;
- bool Init(TaskContext *context) override;
-
- // Appends the features computed from the sentence to the feature vector.
- void Evaluate(const WorkspaceSet &workspaces, const LightSentence &sentence,
- FeatureVector *result) const override;
-
- TC_DEFINE_REGISTRATION_METHOD("continuous-bag-of-relevant-scripts",
- RelevantScriptFeature);
-};
-
-} // namespace lang_id
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_LANG_ID_RELEVANT_SCRIPT_FEATURE_H_
diff --git a/lang_id/script-detector.h b/lang_id/script-detector.h
deleted file mode 100644
index cf816ee..0000000
--- a/lang_id/script-detector.h
+++ /dev/null
@@ -1,162 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_LANG_ID_SCRIPT_DETECTOR_H_
-#define LIBTEXTCLASSIFIER_LANG_ID_SCRIPT_DETECTOR_H_
-
-namespace libtextclassifier {
-namespace nlp_core {
-namespace lang_id {
-
-// Unicode scripts we care about. To get compact and fast code, we detect only
-// a few Unicode scripts that offer a strong indication about the language of
-// the text (e.g., Hiragana -> Japanese).
-enum Script {
- // Special value to indicate internal errors in the script detection code.
- kScriptError,
-
- // Special values for all Unicode scripts that we do not detect. One special
- // value for Unicode characters of 1, 2, 3, respectively 4 bytes (as we
- // already have that information, we use it). kScriptOtherUtf8OneByte means
- // ~Latin and kScriptOtherUtf8FourBytes means ~Han.
- kScriptOtherUtf8OneByte,
- kScriptOtherUtf8TwoBytes,
- kScriptOtherUtf8ThreeBytes,
- kScriptOtherUtf8FourBytes,
-
- kScriptGreek,
- kScriptCyrillic,
- kScriptHebrew,
- kScriptArabic,
- kScriptHangulJamo, // Used primarily for Korean.
- kScriptHiragana, // Used primarily for Japanese.
- kScriptKatakana, // Used primarily for Japanese.
-
- // Add new scripts here.
-
- // Do not add any script after kNumRelevantScripts. This value indicates the
- // number of elements in this enum Script (except this value) such that we can
- // easily iterate over the scripts.
- kNumRelevantScripts,
-};
-
-template<typename IntType>
-inline bool InRange(IntType value, IntType low, IntType hi) {
- return (value >= low) && (value <= hi);
-}
-
-// Returns Script for the UTF8 character that starts at address p.
-// Precondition: p points to a valid UTF8 character of num_bytes bytes.
-inline Script GetScript(const unsigned char *p, int num_bytes) {
- switch (num_bytes) {
- case 1:
- return kScriptOtherUtf8OneByte;
-
- case 2: {
- // 2-byte UTF8 characters have 11 bits of information. unsigned int has
- // at least 16 bits (http://en.cppreference.com/w/cpp/language/types) so
- // it's enough. It's also usually the fastest int type on the current
- // CPU, so it's better to use than int32.
- static const unsigned int kGreekStart = 0x370;
-
- // Commented out (unsued in the code): kGreekEnd = 0x3FF;
- static const unsigned int kCyrillicStart = 0x400;
- static const unsigned int kCyrillicEnd = 0x4FF;
- static const unsigned int kHebrewStart = 0x590;
-
- // Commented out (unsued in the code): kHebrewEnd = 0x5FF;
- static const unsigned int kArabicStart = 0x600;
- static const unsigned int kArabicEnd = 0x6FF;
- const unsigned int codepoint = ((p[0] & 0x1F) << 6) | (p[1] & 0x3F);
- if (codepoint > kCyrillicEnd) {
- if (codepoint >= kArabicStart) {
- if (codepoint <= kArabicEnd) {
- return kScriptArabic;
- }
- } else {
- // At this point, codepoint < kArabicStart = kHebrewEnd + 1, so
- // codepoint <= kHebrewEnd.
- if (codepoint >= kHebrewStart) {
- return kScriptHebrew;
- }
- }
- } else {
- if (codepoint >= kCyrillicStart) {
- return kScriptCyrillic;
- } else {
- // At this point, codepoint < kCyrillicStart = kGreekEnd + 1, so
- // codepoint <= kGreekEnd.
- if (codepoint >= kGreekStart) {
- return kScriptGreek;
- }
- }
- }
- return kScriptOtherUtf8TwoBytes;
- }
-
- case 3: {
- // 3-byte UTF8 characters have 16 bits of information. unsigned int has
- // at least 16 bits.
- static const unsigned int kHangulJamoStart = 0x1100;
- static const unsigned int kHangulJamoEnd = 0x11FF;
- static const unsigned int kHiraganaStart = 0x3041;
- static const unsigned int kHiraganaEnd = 0x309F;
-
- // Commented out (unsued in the code): kKatakanaStart = 0x30A0;
- static const unsigned int kKatakanaEnd = 0x30FF;
- const unsigned int codepoint =
- ((p[0] & 0x0F) << 12) | ((p[1] & 0x3F) << 6) | (p[2] & 0x3F);
- if (codepoint > kHiraganaEnd) {
- // On this branch, codepoint > kHiraganaEnd = kKatakanaStart - 1, so
- // codepoint >= kKatakanaStart.
- if (codepoint <= kKatakanaEnd) {
- return kScriptKatakana;
- }
- } else {
- if (codepoint >= kHiraganaStart) {
- return kScriptHiragana;
- } else {
- if (InRange(codepoint, kHangulJamoStart, kHangulJamoEnd)) {
- return kScriptHangulJamo;
- }
- }
- }
- return kScriptOtherUtf8ThreeBytes;
- }
-
- case 4:
- return kScriptOtherUtf8FourBytes;
-
- default:
- return kScriptError;
- }
-}
-
-// Returns Script for the UTF8 character that starts at address p. Similar to
-// the previous version of GetScript, except for "char" vs "unsigned char".
-// Most code works with "char *" pointers, ignoring the fact that char is
-// unsigned (by default) on most platforms, but signed on iOS. This code takes
-// care of making sure we always treat chars as unsigned.
-inline Script GetScript(const char *p, int num_bytes) {
- return GetScript(reinterpret_cast<const unsigned char *>(p),
- num_bytes);
-}
-
-} // namespace lang_id
-} // namespace nlp_core
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_LANG_ID_SCRIPT_DETECTOR_H_
diff --git a/model-executor.cc b/model-executor.cc
new file mode 100644
index 0000000..2b1fc11
--- /dev/null
+++ b/model-executor.cc
@@ -0,0 +1,124 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "model-executor.h"
+
+#include "util/base/logging.h"
+
+namespace libtextclassifier2 {
+namespace internal {
+bool FromModelSpec(const tflite::Model* model_spec,
+ std::unique_ptr<tflite::FlatBufferModel>* model,
+ std::unique_ptr<tflite::Interpreter>* interpreter) {
+ *model = tflite::FlatBufferModel::BuildFromModel(model_spec);
+ if (!(*model) || !(*model)->initialized()) {
+ TC_LOG(ERROR) << "Could not build TFLite model from a model spec. ";
+ return false;
+ }
+
+ tflite::ops::builtin::BuiltinOpResolver builtins;
+ tflite::InterpreterBuilder(**model, builtins)(interpreter);
+ if (!interpreter) {
+ TC_LOG(ERROR) << "Could not build TFLite interpreter.";
+ return false;
+ }
+ return true;
+}
+} // namespace internal
+
+TFLiteEmbeddingExecutor::TFLiteEmbeddingExecutor(
+ const tflite::Model* model_spec) {
+ internal::FromModelSpec(model_spec, &model_, &interpreter_);
+ if (!interpreter_) {
+ return;
+ }
+ if (interpreter_->tensors_size() != 2) {
+ return;
+ }
+ embeddings_ = interpreter_->tensor(0);
+ if (embeddings_->dims->size != 2) {
+ return;
+ }
+ num_buckets_ = embeddings_->dims->data[0];
+ scales_ = interpreter_->tensor(1);
+ if (scales_->dims->size != 2 || scales_->dims->data[0] != num_buckets_ ||
+ scales_->dims->data[1] != 1) {
+ return;
+ }
+ embedding_size_ = embeddings_->dims->data[1];
+ initialized_ = true;
+}
+
+bool TFLiteEmbeddingExecutor::AddEmbedding(
+ const TensorView<int>& sparse_features, float* dest, int dest_size) {
+ if (!initialized_ || dest_size != embedding_size_) {
+ return false;
+ }
+ const int num_sparse_features = sparse_features.size();
+ for (int i = 0; i < num_sparse_features; ++i) {
+ const int bucket_id = sparse_features.data()[i];
+ if (bucket_id >= num_buckets_) {
+ return false;
+ }
+ const float multiplier = scales_->data.f[bucket_id];
+ for (int k = 0; k < embedding_size_; ++k) {
+ // Dequantize and add the embedding.
+ dest[k] +=
+ 1.0 / num_sparse_features *
+ (static_cast<int>(
+ embeddings_->data.uint8[bucket_id * embedding_size_ + k]) -
+ kQuantBias) *
+ multiplier;
+ }
+ }
+ return true;
+}
+
+TensorView<float> ComputeLogitsHelper(const int input_index_features,
+ const int output_index_logits,
+ const TensorView<float>& features,
+ tflite::Interpreter* interpreter) {
+ interpreter->ResizeInputTensor(input_index_features, features.shape());
+ if (interpreter->AllocateTensors() != kTfLiteOk) {
+ TC_VLOG(1) << "Allocation failed.";
+ return TensorView<float>::Invalid();
+ }
+
+ TfLiteTensor* features_tensor =
+ interpreter->tensor(interpreter->inputs()[input_index_features]);
+ int size = 1;
+ for (int i = 0; i < features_tensor->dims->size; ++i) {
+ size *= features_tensor->dims->data[i];
+ }
+ features.copy_to(features_tensor->data.f, size);
+
+ if (interpreter->Invoke() != kTfLiteOk) {
+ TC_VLOG(1) << "Interpreter failed.";
+ return TensorView<float>::Invalid();
+ }
+
+ TfLiteTensor* logits_tensor =
+ interpreter->tensor(interpreter->outputs()[output_index_logits]);
+
+ std::vector<int> output_shape(logits_tensor->dims->size);
+ for (int i = 0; i < logits_tensor->dims->size; ++i) {
+ output_shape[i] = logits_tensor->dims->data[i];
+ }
+
+ return TensorView<float>(logits_tensor->data.f, output_shape);
+}
+
+} // namespace libtextclassifier2
diff --git a/model-executor.h b/model-executor.h
new file mode 100644
index 0000000..b16d53d
--- /dev/null
+++ b/model-executor.h
@@ -0,0 +1,104 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Contains classes that can execute different models/parts of a model.
+
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_MODEL_EXECUTOR_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_MODEL_EXECUTOR_H_
+
+#include <memory>
+
+#include "tensor-view.h"
+#include "types.h"
+#include "util/base/logging.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace libtextclassifier2 {
+
+namespace internal {
+bool FromModelSpec(const tflite::Model* model_spec,
+ std::unique_ptr<tflite::FlatBufferModel>* model,
+ std::unique_ptr<tflite::Interpreter>* interpreter);
+} // namespace internal
+
+// A helper function that given indices of feature and logits tensor, feature
+// values computes the logits using given interpreter.
+TensorView<float> ComputeLogitsHelper(const int input_index_features,
+ const int output_index_logits,
+ const TensorView<float>& features,
+ tflite::Interpreter* interpreter);
+
+// Executor for the text selection prediction and classification models.
+// NOTE: This class is not thread-safe.
+class ModelExecutor {
+ public:
+ explicit ModelExecutor(const tflite::Model* model_spec) {
+ internal::FromModelSpec(model_spec, &model_, &interpreter_);
+ }
+
+ TensorView<float> ComputeLogits(const TensorView<float>& features) {
+ return ComputeLogitsHelper(kInputIndexFeatures, kOutputIndexLogits,
+ features, interpreter_.get());
+ }
+
+ protected:
+ static const int kInputIndexFeatures = 0;
+ static const int kOutputIndexLogits = 0;
+
+ std::unique_ptr<tflite::FlatBufferModel> model_ = nullptr;
+ std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
+};
+
+// Executor for embedding sparse features into a dense vector.
+class EmbeddingExecutor {
+ public:
+ virtual ~EmbeddingExecutor() {}
+
+ // Embeds the sparse_features into a dense embedding and adds (+) it
+ // element-wise to the dest vector.
+ virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
+ int dest_size) = 0;
+
+ // Returns true when the model is ready to be used, false otherwise.
+ virtual bool IsReady() { return true; }
+};
+
+// NOTE: This class is not thread-safe.
+class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
+ public:
+ explicit TFLiteEmbeddingExecutor(const tflite::Model* model_spec);
+ bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
+ int dest_size) override;
+
+ bool IsReady() override { return initialized_; }
+
+ protected:
+ static const int kQuantBias = 128;
+ bool initialized_ = false;
+ int num_buckets_ = -1;
+ int embedding_size_ = -1;
+ const TfLiteTensor* scales_ = nullptr;
+ const TfLiteTensor* embeddings_ = nullptr;
+
+ std::unique_ptr<tflite::FlatBufferModel> model_ = nullptr;
+ std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
+};
+
+} // namespace libtextclassifier2
+
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_MODEL_EXECUTOR_H_
diff --git a/model.fbs b/model.fbs
new file mode 100755
index 0000000..d98e5ac
--- /dev/null
+++ b/model.fbs
@@ -0,0 +1,145 @@
+// Generated from model.proto
+
+namespace libtextclassifier2.TokenizationCodepointRange_;
+
+enum Role : int {
+ DEFAULT_ROLE = 0,
+ SPLIT_BEFORE = 1,
+ SPLIT_AFTER = 2,
+ TOKEN_SEPARATOR = 3,
+ DISCARD_CODEPOINT = 4,
+ WHITESPACE_SEPARATOR = 7,
+}
+
+namespace libtextclassifier2.FeatureProcessorOptions_;
+
+enum CenterTokenSelectionMethod : int {
+ DEFAULT_CENTER_TOKEN_METHOD = 0,
+ CENTER_TOKEN_FROM_CLICK = 1,
+ CENTER_TOKEN_MIDDLE_OF_SELECTION = 2,
+}
+
+enum TokenizationType : int {
+ INVALID_TOKENIZATION_TYPE = 0,
+ INTERNAL_TOKENIZER = 1,
+ ICU = 2,
+ MIXED = 3,
+}
+
+namespace libtextclassifier2;
+
+table SelectionModelOptions {
+ strip_unpaired_brackets:bool;
+ symmetry_context_size:int;
+}
+
+table ClassificationModelOptions {
+ phone_min_num_digits:int = 7;
+ phone_max_num_digits:int = 15;
+}
+
+table RegexModelOptions {
+ patterns:[libtextclassifier2.RegexModelOptions_.Pattern];
+}
+
+namespace libtextclassifier2.RegexModelOptions_;
+
+table Pattern {
+ collection_name:string;
+ pattern:string;
+}
+
+namespace libtextclassifier2;
+
+table StructuredRegexModel {
+ patterns:[libtextclassifier2.StructuredRegexModel_.StructuredPattern];
+}
+
+namespace libtextclassifier2.StructuredRegexModel_;
+
+table StructuredPattern {
+ pattern:string;
+ node_names:[string];
+}
+
+namespace libtextclassifier2;
+
+table Model {
+ language:string;
+ version:int;
+ selection_feature_options:libtextclassifier2.FeatureProcessorOptions;
+ classification_feature_options:libtextclassifier2.FeatureProcessorOptions;
+ selection_model:[ubyte];
+ classification_model:[ubyte];
+ embedding_model:[ubyte];
+ regex_options:libtextclassifier2.RegexModelOptions;
+ selection_options:libtextclassifier2.SelectionModelOptions;
+ classification_options:libtextclassifier2.ClassificationModelOptions;
+ regex_model:libtextclassifier2.StructuredRegexModel;
+}
+
+table TokenizationCodepointRange {
+ start:int;
+ end:int;
+ role:libtextclassifier2.TokenizationCodepointRange_.Role;
+ script_id:int;
+}
+
+table FeatureProcessorOptions {
+ num_buckets:int = -1;
+ embedding_size:int = -1;
+ context_size:int = -1;
+ max_selection_span:int = -1;
+ chargram_orders:[int];
+ max_word_length:int = 20;
+ unicode_aware_features:bool;
+ extract_case_feature:bool;
+ extract_selection_mask_feature:bool;
+ regexp_feature:[string];
+ remap_digits:bool;
+ lowercase_tokens:bool;
+ selection_reduced_output_space:bool;
+ collections:[string];
+ default_collection:int = -1;
+ only_use_line_with_click:bool;
+ split_tokens_on_selection_boundaries:bool;
+ tokenization_codepoint_config:[libtextclassifier2.TokenizationCodepointRange];
+ center_token_selection_method:libtextclassifier2.FeatureProcessorOptions_.CenterTokenSelectionMethod;
+ snap_label_span_boundaries_to_containing_tokens:bool;
+ supported_codepoint_ranges:[libtextclassifier2.FeatureProcessorOptions_.CodepointRange];
+ internal_tokenizer_codepoint_ranges:[libtextclassifier2.FeatureProcessorOptions_.CodepointRange];
+ min_supported_codepoint_ratio:float = 0.0;
+ feature_version:int;
+ tokenization_type:libtextclassifier2.FeatureProcessorOptions_.TokenizationType;
+ icu_preserve_whitespace_tokens:bool;
+ ignored_span_boundary_codepoints:[int];
+ click_random_token_in_selection:bool;
+ alternative_collection_map:[libtextclassifier2.FeatureProcessorOptions_.CollectionMapEntry];
+ bounds_sensitive_features:libtextclassifier2.FeatureProcessorOptions_.BoundsSensitiveFeatures;
+ split_selection_candidates:bool;
+ allowed_chargrams:[string];
+ tokenize_on_script_change:bool;
+}
+
+namespace libtextclassifier2.FeatureProcessorOptions_;
+
+table CodepointRange {
+ start:int;
+ end:int;
+}
+
+table CollectionMapEntry {
+ key:string;
+ value:string;
+}
+
+table BoundsSensitiveFeatures {
+ enabled:bool;
+ num_tokens_before:int;
+ num_tokens_inside_left:int;
+ num_tokens_inside_right:int;
+ num_tokens_after:int;
+ include_inside_bag:bool;
+ include_inside_length:bool;
+}
+
diff --git a/model_generated.h b/model_generated.h
new file mode 100755
index 0000000..fd11c39
--- /dev/null
+++ b/model_generated.h
@@ -0,0 +1,2189 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_MODEL_LIBTEXTCLASSIFIER2_FEATUREPROCESSOROPTIONS__H_
+#define FLATBUFFERS_GENERATED_MODEL_LIBTEXTCLASSIFIER2_FEATUREPROCESSOROPTIONS__H_
+
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier2 {
+
+struct SelectionModelOptions;
+struct SelectionModelOptionsT;
+
+struct ClassificationModelOptions;
+struct ClassificationModelOptionsT;
+
+struct RegexModelOptions;
+struct RegexModelOptionsT;
+
+namespace RegexModelOptions_ {
+
+struct Pattern;
+struct PatternT;
+
+} // namespace RegexModelOptions_
+
+struct StructuredRegexModel;
+struct StructuredRegexModelT;
+
+namespace StructuredRegexModel_ {
+
+struct StructuredPattern;
+struct StructuredPatternT;
+
+} // namespace StructuredRegexModel_
+
+struct Model;
+struct ModelT;
+
+struct TokenizationCodepointRange;
+struct TokenizationCodepointRangeT;
+
+struct FeatureProcessorOptions;
+struct FeatureProcessorOptionsT;
+
+namespace FeatureProcessorOptions_ {
+
+struct CodepointRange;
+struct CodepointRangeT;
+
+struct CollectionMapEntry;
+struct CollectionMapEntryT;
+
+struct BoundsSensitiveFeatures;
+struct BoundsSensitiveFeaturesT;
+
+} // namespace FeatureProcessorOptions_
+
+namespace TokenizationCodepointRange_ {
+
+enum Role {
+ Role_DEFAULT_ROLE = 0,
+ Role_SPLIT_BEFORE = 1,
+ Role_SPLIT_AFTER = 2,
+ Role_TOKEN_SEPARATOR = 3,
+ Role_DISCARD_CODEPOINT = 4,
+ Role_WHITESPACE_SEPARATOR = 7,
+ Role_MIN = Role_DEFAULT_ROLE,
+ Role_MAX = Role_WHITESPACE_SEPARATOR
+};
+
+inline Role (&EnumValuesRole())[6] {
+ static Role values[] = {
+ Role_DEFAULT_ROLE,
+ Role_SPLIT_BEFORE,
+ Role_SPLIT_AFTER,
+ Role_TOKEN_SEPARATOR,
+ Role_DISCARD_CODEPOINT,
+ Role_WHITESPACE_SEPARATOR
+ };
+ return values;
+}
+
+inline const char **EnumNamesRole() {
+ static const char *names[] = {
+ "DEFAULT_ROLE",
+ "SPLIT_BEFORE",
+ "SPLIT_AFTER",
+ "TOKEN_SEPARATOR",
+ "DISCARD_CODEPOINT",
+ "",
+ "",
+ "WHITESPACE_SEPARATOR",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameRole(Role e) {
+ const size_t index = static_cast<int>(e);
+ return EnumNamesRole()[index];
+}
+
+} // namespace TokenizationCodepointRange_
+
+namespace FeatureProcessorOptions_ {
+
+enum CenterTokenSelectionMethod {
+ CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD = 0,
+ CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK = 1,
+ CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION = 2,
+ CenterTokenSelectionMethod_MIN = CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD,
+ CenterTokenSelectionMethod_MAX = CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION
+};
+
+inline CenterTokenSelectionMethod (&EnumValuesCenterTokenSelectionMethod())[3] {
+ static CenterTokenSelectionMethod values[] = {
+ CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD,
+ CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK,
+ CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION
+ };
+ return values;
+}
+
+inline const char **EnumNamesCenterTokenSelectionMethod() {
+ static const char *names[] = {
+ "DEFAULT_CENTER_TOKEN_METHOD",
+ "CENTER_TOKEN_FROM_CLICK",
+ "CENTER_TOKEN_MIDDLE_OF_SELECTION",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameCenterTokenSelectionMethod(CenterTokenSelectionMethod e) {
+ const size_t index = static_cast<int>(e);
+ return EnumNamesCenterTokenSelectionMethod()[index];
+}
+
+enum TokenizationType {
+ TokenizationType_INVALID_TOKENIZATION_TYPE = 0,
+ TokenizationType_INTERNAL_TOKENIZER = 1,
+ TokenizationType_ICU = 2,
+ TokenizationType_MIXED = 3,
+ TokenizationType_MIN = TokenizationType_INVALID_TOKENIZATION_TYPE,
+ TokenizationType_MAX = TokenizationType_MIXED
+};
+
+inline TokenizationType (&EnumValuesTokenizationType())[4] {
+ static TokenizationType values[] = {
+ TokenizationType_INVALID_TOKENIZATION_TYPE,
+ TokenizationType_INTERNAL_TOKENIZER,
+ TokenizationType_ICU,
+ TokenizationType_MIXED
+ };
+ return values;
+}
+
+inline const char **EnumNamesTokenizationType() {
+ static const char *names[] = {
+ "INVALID_TOKENIZATION_TYPE",
+ "INTERNAL_TOKENIZER",
+ "ICU",
+ "MIXED",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameTokenizationType(TokenizationType e) {
+ const size_t index = static_cast<int>(e);
+ return EnumNamesTokenizationType()[index];
+}
+
+} // namespace FeatureProcessorOptions_
+
+struct SelectionModelOptionsT : public flatbuffers::NativeTable {
+ typedef SelectionModelOptions TableType;
+ bool strip_unpaired_brackets;
+ int32_t symmetry_context_size;
+ SelectionModelOptionsT()
+ : strip_unpaired_brackets(false),
+ symmetry_context_size(0) {
+ }
+};
+
+struct SelectionModelOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SelectionModelOptionsT NativeTableType;
+ enum {
+ VT_STRIP_UNPAIRED_BRACKETS = 4,
+ VT_SYMMETRY_CONTEXT_SIZE = 6
+ };
+ bool strip_unpaired_brackets() const {
+ return GetField<uint8_t>(VT_STRIP_UNPAIRED_BRACKETS, 0) != 0;
+ }
+ int32_t symmetry_context_size() const {
+ return GetField<int32_t>(VT_SYMMETRY_CONTEXT_SIZE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_STRIP_UNPAIRED_BRACKETS) &&
+ VerifyField<int32_t>(verifier, VT_SYMMETRY_CONTEXT_SIZE) &&
+ verifier.EndTable();
+ }
+ SelectionModelOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(SelectionModelOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<SelectionModelOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const SelectionModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct SelectionModelOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_strip_unpaired_brackets(bool strip_unpaired_brackets) {
+ fbb_.AddElement<uint8_t>(SelectionModelOptions::VT_STRIP_UNPAIRED_BRACKETS, static_cast<uint8_t>(strip_unpaired_brackets), 0);
+ }
+ void add_symmetry_context_size(int32_t symmetry_context_size) {
+ fbb_.AddElement<int32_t>(SelectionModelOptions::VT_SYMMETRY_CONTEXT_SIZE, symmetry_context_size, 0);
+ }
+ explicit SelectionModelOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SelectionModelOptionsBuilder &operator=(const SelectionModelOptionsBuilder &);
+ flatbuffers::Offset<SelectionModelOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SelectionModelOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SelectionModelOptions> CreateSelectionModelOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ bool strip_unpaired_brackets = false,
+ int32_t symmetry_context_size = 0) {
+ SelectionModelOptionsBuilder builder_(_fbb);
+ builder_.add_symmetry_context_size(symmetry_context_size);
+ builder_.add_strip_unpaired_brackets(strip_unpaired_brackets);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<SelectionModelOptions> CreateSelectionModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const SelectionModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct ClassificationModelOptionsT : public flatbuffers::NativeTable {
+ typedef ClassificationModelOptions TableType;
+ int32_t phone_min_num_digits;
+ int32_t phone_max_num_digits;
+ ClassificationModelOptionsT()
+ : phone_min_num_digits(7),
+ phone_max_num_digits(15) {
+ }
+};
+
+struct ClassificationModelOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ClassificationModelOptionsT NativeTableType;
+ enum {
+ VT_PHONE_MIN_NUM_DIGITS = 4,
+ VT_PHONE_MAX_NUM_DIGITS = 6
+ };
+ int32_t phone_min_num_digits() const {
+ return GetField<int32_t>(VT_PHONE_MIN_NUM_DIGITS, 7);
+ }
+ int32_t phone_max_num_digits() const {
+ return GetField<int32_t>(VT_PHONE_MAX_NUM_DIGITS, 15);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_PHONE_MIN_NUM_DIGITS) &&
+ VerifyField<int32_t>(verifier, VT_PHONE_MAX_NUM_DIGITS) &&
+ verifier.EndTable();
+ }
+ ClassificationModelOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ClassificationModelOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ClassificationModelOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ClassificationModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ClassificationModelOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_phone_min_num_digits(int32_t phone_min_num_digits) {
+ fbb_.AddElement<int32_t>(ClassificationModelOptions::VT_PHONE_MIN_NUM_DIGITS, phone_min_num_digits, 7);
+ }
+ void add_phone_max_num_digits(int32_t phone_max_num_digits) {
+ fbb_.AddElement<int32_t>(ClassificationModelOptions::VT_PHONE_MAX_NUM_DIGITS, phone_max_num_digits, 15);
+ }
+ explicit ClassificationModelOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ClassificationModelOptionsBuilder &operator=(const ClassificationModelOptionsBuilder &);
+ flatbuffers::Offset<ClassificationModelOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ClassificationModelOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ClassificationModelOptions> CreateClassificationModelOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t phone_min_num_digits = 7,
+ int32_t phone_max_num_digits = 15) {
+ ClassificationModelOptionsBuilder builder_(_fbb);
+ builder_.add_phone_max_num_digits(phone_max_num_digits);
+ builder_.add_phone_min_num_digits(phone_min_num_digits);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<ClassificationModelOptions> CreateClassificationModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const ClassificationModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+// Native (object API) mirror of the RegexModelOptions table: owns its
+// pattern entries as unique_ptrs instead of referencing buffer memory.
+struct RegexModelOptionsT : public flatbuffers::NativeTable {
+ typedef RegexModelOptions TableType;
+ std::vector<std::unique_ptr<libtextclassifier2::RegexModelOptions_::PatternT>> patterns;
+ RegexModelOptionsT() {
+ }
+};
+
+// Zero-copy accessor view over a serialized RegexModelOptions table.
+struct RegexModelOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef RegexModelOptionsT NativeTableType;
+ enum {
+ VT_PATTERNS = 4
+ };
+ const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModelOptions_::Pattern>> *patterns() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModelOptions_::Pattern>> *>(VT_PATTERNS);
+ }
+ // Checks that the patterns vector and each contained table lie within the
+ // buffer bounds.
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_PATTERNS) &&
+ verifier.Verify(patterns()) &&
+ verifier.VerifyVectorOfTables(patterns()) &&
+ verifier.EndTable();
+ }
+ RegexModelOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(RegexModelOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<RegexModelOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+// Incremental builder for a RegexModelOptions table.
+struct RegexModelOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_patterns(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModelOptions_::Pattern>>> patterns) {
+ fbb_.AddOffset(RegexModelOptions::VT_PATTERNS, patterns);
+ }
+ explicit RegexModelOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ RegexModelOptionsBuilder &operator=(const RegexModelOptionsBuilder &);
+ flatbuffers::Offset<RegexModelOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<RegexModelOptions>(end);
+ return o;
+ }
+};
+
+// Builds a RegexModelOptions table from an already-serialized vector offset.
+inline flatbuffers::Offset<RegexModelOptions> CreateRegexModelOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::RegexModelOptions_::Pattern>>> patterns = 0) {
+ RegexModelOptionsBuilder builder_(_fbb);
+ builder_.add_patterns(patterns);
+ return builder_.Finish();
+}
+
+// Convenience overload: serializes the std::vector of pattern offsets first,
+// then delegates to CreateRegexModelOptions.
+inline flatbuffers::Offset<RegexModelOptions> CreateRegexModelOptionsDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<libtextclassifier2::RegexModelOptions_::Pattern>> *patterns = nullptr) {
+ return libtextclassifier2::CreateRegexModelOptions(
+ _fbb,
+ patterns ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::RegexModelOptions_::Pattern>>(*patterns) : 0);
+}
+
+// Object-API pack overload; implementation not in this chunk.
+flatbuffers::Offset<RegexModelOptions> CreateRegexModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+// Nested namespace for tables declared inside RegexModelOptions in the
+// schema (flatc maps nested declarations to a trailing-underscore namespace).
+namespace RegexModelOptions_ {
+
+// Native (object API) mirror of a single regex Pattern entry.
+struct PatternT : public flatbuffers::NativeTable {
+ typedef Pattern TableType;
+ std::string collection_name;
+ std::string pattern;
+ PatternT() {
+ }
+};
+
+// Zero-copy accessor view over a serialized Pattern table.
+struct Pattern FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PatternT NativeTableType;
+ enum {
+ VT_COLLECTION_NAME = 4,
+ VT_PATTERN = 6
+ };
+ const flatbuffers::String *collection_name() const {
+ return GetPointer<const flatbuffers::String *>(VT_COLLECTION_NAME);
+ }
+ const flatbuffers::String *pattern() const {
+ return GetPointer<const flatbuffers::String *>(VT_PATTERN);
+ }
+ // Bounds-checks both string fields against the buffer.
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_COLLECTION_NAME) &&
+ verifier.Verify(collection_name()) &&
+ VerifyOffset(verifier, VT_PATTERN) &&
+ verifier.Verify(pattern()) &&
+ verifier.EndTable();
+ }
+ PatternT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PatternT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<Pattern> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PatternT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+// Incremental builder for a Pattern table.
+struct PatternBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_collection_name(flatbuffers::Offset<flatbuffers::String> collection_name) {
+ fbb_.AddOffset(Pattern::VT_COLLECTION_NAME, collection_name);
+ }
+ void add_pattern(flatbuffers::Offset<flatbuffers::String> pattern) {
+ fbb_.AddOffset(Pattern::VT_PATTERN, pattern);
+ }
+ explicit PatternBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PatternBuilder &operator=(const PatternBuilder &);
+ flatbuffers::Offset<Pattern> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Pattern>(end);
+ return o;
+ }
+};
+
+// Builds a Pattern table from already-serialized string offsets.
+inline flatbuffers::Offset<Pattern> CreatePattern(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> collection_name = 0,
+ flatbuffers::Offset<flatbuffers::String> pattern = 0) {
+ PatternBuilder builder_(_fbb);
+ builder_.add_pattern(pattern);
+ builder_.add_collection_name(collection_name);
+ return builder_.Finish();
+}
+
+// Convenience overload: serializes the C strings first, then delegates.
+inline flatbuffers::Offset<Pattern> CreatePatternDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *collection_name = nullptr,
+ const char *pattern = nullptr) {
+ return libtextclassifier2::RegexModelOptions_::CreatePattern(
+ _fbb,
+ collection_name ? _fbb.CreateString(collection_name) : 0,
+ pattern ? _fbb.CreateString(pattern) : 0);
+}
+
+// Object-API pack overload; implementation not in this chunk.
+flatbuffers::Offset<Pattern> CreatePattern(flatbuffers::FlatBufferBuilder &_fbb, const PatternT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+} // namespace RegexModelOptions_
+
+// Native (object API) mirror of the StructuredRegexModel table.
+struct StructuredRegexModelT : public flatbuffers::NativeTable {
+ typedef StructuredRegexModel TableType;
+ std::vector<std::unique_ptr<libtextclassifier2::StructuredRegexModel_::StructuredPatternT>> patterns;
+ StructuredRegexModelT() {
+ }
+};
+
+// Zero-copy accessor view over a serialized StructuredRegexModel table.
+struct StructuredRegexModel FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef StructuredRegexModelT NativeTableType;
+ enum {
+ VT_PATTERNS = 4
+ };
+ const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::StructuredRegexModel_::StructuredPattern>> *patterns() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::StructuredRegexModel_::StructuredPattern>> *>(VT_PATTERNS);
+ }
+ // Bounds-checks the patterns vector and each contained table.
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_PATTERNS) &&
+ verifier.Verify(patterns()) &&
+ verifier.VerifyVectorOfTables(patterns()) &&
+ verifier.EndTable();
+ }
+ StructuredRegexModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(StructuredRegexModelT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<StructuredRegexModel> Pack(flatbuffers::FlatBufferBuilder &_fbb, const StructuredRegexModelT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+// Incremental builder for a StructuredRegexModel table.
+struct StructuredRegexModelBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_patterns(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::StructuredRegexModel_::StructuredPattern>>> patterns) {
+ fbb_.AddOffset(StructuredRegexModel::VT_PATTERNS, patterns);
+ }
+ explicit StructuredRegexModelBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ StructuredRegexModelBuilder &operator=(const StructuredRegexModelBuilder &);
+ flatbuffers::Offset<StructuredRegexModel> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<StructuredRegexModel>(end);
+ return o;
+ }
+};
+
+// Builds a StructuredRegexModel from an already-serialized vector offset.
+inline flatbuffers::Offset<StructuredRegexModel> CreateStructuredRegexModel(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::StructuredRegexModel_::StructuredPattern>>> patterns = 0) {
+ StructuredRegexModelBuilder builder_(_fbb);
+ builder_.add_patterns(patterns);
+ return builder_.Finish();
+}
+
+// Convenience overload: serializes the std::vector of offsets first.
+inline flatbuffers::Offset<StructuredRegexModel> CreateStructuredRegexModelDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<libtextclassifier2::StructuredRegexModel_::StructuredPattern>> *patterns = nullptr) {
+ return libtextclassifier2::CreateStructuredRegexModel(
+ _fbb,
+ patterns ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::StructuredRegexModel_::StructuredPattern>>(*patterns) : 0);
+}
+
+// Object-API pack overload; implementation not in this chunk.
+flatbuffers::Offset<StructuredRegexModel> CreateStructuredRegexModel(flatbuffers::FlatBufferBuilder &_fbb, const StructuredRegexModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+// Nested namespace for tables declared inside StructuredRegexModel in the
+// schema.
+namespace StructuredRegexModel_ {
+
+// Native (object API) mirror of a StructuredPattern: a regex plus the names
+// assigned to its capturing groups.
+struct StructuredPatternT : public flatbuffers::NativeTable {
+ typedef StructuredPattern TableType;
+ std::string pattern;
+ std::vector<std::string> node_names;
+ StructuredPatternT() {
+ }
+};
+
+// Zero-copy accessor view over a serialized StructuredPattern table.
+struct StructuredPattern FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef StructuredPatternT NativeTableType;
+ enum {
+ VT_PATTERN = 4,
+ VT_NODE_NAMES = 6
+ };
+ const flatbuffers::String *pattern() const {
+ return GetPointer<const flatbuffers::String *>(VT_PATTERN);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *node_names() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_NODE_NAMES);
+ }
+ // Bounds-checks the pattern string and every node-name string.
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_PATTERN) &&
+ verifier.Verify(pattern()) &&
+ VerifyOffset(verifier, VT_NODE_NAMES) &&
+ verifier.Verify(node_names()) &&
+ verifier.VerifyVectorOfStrings(node_names()) &&
+ verifier.EndTable();
+ }
+ StructuredPatternT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(StructuredPatternT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<StructuredPattern> Pack(flatbuffers::FlatBufferBuilder &_fbb, const StructuredPatternT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+// Incremental builder for a StructuredPattern table.
+struct StructuredPatternBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_pattern(flatbuffers::Offset<flatbuffers::String> pattern) {
+ fbb_.AddOffset(StructuredPattern::VT_PATTERN, pattern);
+ }
+ void add_node_names(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> node_names) {
+ fbb_.AddOffset(StructuredPattern::VT_NODE_NAMES, node_names);
+ }
+ explicit StructuredPatternBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ StructuredPatternBuilder &operator=(const StructuredPatternBuilder &);
+ flatbuffers::Offset<StructuredPattern> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<StructuredPattern>(end);
+ return o;
+ }
+};
+
+// Builds a StructuredPattern from already-serialized offsets.
+inline flatbuffers::Offset<StructuredPattern> CreateStructuredPattern(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> pattern = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> node_names = 0) {
+ StructuredPatternBuilder builder_(_fbb);
+ builder_.add_node_names(node_names);
+ builder_.add_pattern(pattern);
+ return builder_.Finish();
+}
+
+// Convenience overload: serializes the C string and offset vector first.
+inline flatbuffers::Offset<StructuredPattern> CreateStructuredPatternDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *pattern = nullptr,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *node_names = nullptr) {
+ return libtextclassifier2::StructuredRegexModel_::CreateStructuredPattern(
+ _fbb,
+ pattern ? _fbb.CreateString(pattern) : 0,
+ node_names ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*node_names) : 0);
+}
+
+// Object-API pack overload; implementation not in this chunk.
+flatbuffers::Offset<StructuredPattern> CreateStructuredPattern(flatbuffers::FlatBufferBuilder &_fbb, const StructuredPatternT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+} // namespace StructuredRegexModel_
+
+// Native (object API) mirror of the top-level Model table: language tag,
+// version, per-task feature options, raw embedded model blobs, and the
+// various option sub-tables.
+struct ModelT : public flatbuffers::NativeTable {
+ typedef Model TableType;
+ std::string language;
+ int32_t version;
+ std::unique_ptr<FeatureProcessorOptionsT> selection_feature_options;
+ std::unique_ptr<FeatureProcessorOptionsT> classification_feature_options;
+ std::vector<uint8_t> selection_model;
+ std::vector<uint8_t> classification_model;
+ std::vector<uint8_t> embedding_model;
+ std::unique_ptr<RegexModelOptionsT> regex_options;
+ std::unique_ptr<SelectionModelOptionsT> selection_options;
+ std::unique_ptr<ClassificationModelOptionsT> classification_options;
+ std::unique_ptr<StructuredRegexModelT> regex_model;
+ ModelT()
+ : version(0) {
+ }
+};
+
+// Zero-copy accessor view over a serialized Model table.
+struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ModelT NativeTableType;
+ enum {
+ VT_LANGUAGE = 4,
+ VT_VERSION = 6,
+ VT_SELECTION_FEATURE_OPTIONS = 8,
+ VT_CLASSIFICATION_FEATURE_OPTIONS = 10,
+ VT_SELECTION_MODEL = 12,
+ VT_CLASSIFICATION_MODEL = 14,
+ VT_EMBEDDING_MODEL = 16,
+ VT_REGEX_OPTIONS = 18,
+ VT_SELECTION_OPTIONS = 20,
+ VT_CLASSIFICATION_OPTIONS = 22,
+ VT_REGEX_MODEL = 24
+ };
+ const flatbuffers::String *language() const {
+ return GetPointer<const flatbuffers::String *>(VT_LANGUAGE);
+ }
+ int32_t version() const {
+ return GetField<int32_t>(VT_VERSION, 0);
+ }
+ const FeatureProcessorOptions *selection_feature_options() const {
+ return GetPointer<const FeatureProcessorOptions *>(VT_SELECTION_FEATURE_OPTIONS);
+ }
+ const FeatureProcessorOptions *classification_feature_options() const {
+ return GetPointer<const FeatureProcessorOptions *>(VT_CLASSIFICATION_FEATURE_OPTIONS);
+ }
+ const flatbuffers::Vector<uint8_t> *selection_model() const {
+ return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_SELECTION_MODEL);
+ }
+ const flatbuffers::Vector<uint8_t> *classification_model() const {
+ return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CLASSIFICATION_MODEL);
+ }
+ const flatbuffers::Vector<uint8_t> *embedding_model() const {
+ return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_EMBEDDING_MODEL);
+ }
+ const RegexModelOptions *regex_options() const {
+ return GetPointer<const RegexModelOptions *>(VT_REGEX_OPTIONS);
+ }
+ const SelectionModelOptions *selection_options() const {
+ return GetPointer<const SelectionModelOptions *>(VT_SELECTION_OPTIONS);
+ }
+ const ClassificationModelOptions *classification_options() const {
+ return GetPointer<const ClassificationModelOptions *>(VT_CLASSIFICATION_OPTIONS);
+ }
+ const StructuredRegexModel *regex_model() const {
+ return GetPointer<const StructuredRegexModel *>(VT_REGEX_MODEL);
+ }
+ // Recursively bounds-checks every field, including nested sub-tables.
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_LANGUAGE) &&
+ verifier.Verify(language()) &&
+ VerifyField<int32_t>(verifier, VT_VERSION) &&
+ VerifyOffset(verifier, VT_SELECTION_FEATURE_OPTIONS) &&
+ verifier.VerifyTable(selection_feature_options()) &&
+ VerifyOffset(verifier, VT_CLASSIFICATION_FEATURE_OPTIONS) &&
+ verifier.VerifyTable(classification_feature_options()) &&
+ VerifyOffset(verifier, VT_SELECTION_MODEL) &&
+ verifier.Verify(selection_model()) &&
+ VerifyOffset(verifier, VT_CLASSIFICATION_MODEL) &&
+ verifier.Verify(classification_model()) &&
+ VerifyOffset(verifier, VT_EMBEDDING_MODEL) &&
+ verifier.Verify(embedding_model()) &&
+ VerifyOffset(verifier, VT_REGEX_OPTIONS) &&
+ verifier.VerifyTable(regex_options()) &&
+ VerifyOffset(verifier, VT_SELECTION_OPTIONS) &&
+ verifier.VerifyTable(selection_options()) &&
+ VerifyOffset(verifier, VT_CLASSIFICATION_OPTIONS) &&
+ verifier.VerifyTable(classification_options()) &&
+ VerifyOffset(verifier, VT_REGEX_MODEL) &&
+ verifier.VerifyTable(regex_model()) &&
+ verifier.EndTable();
+ }
+ ModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<Model> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+// Incremental builder for a Model table.
+struct ModelBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_language(flatbuffers::Offset<flatbuffers::String> language) {
+ fbb_.AddOffset(Model::VT_LANGUAGE, language);
+ }
+ void add_version(int32_t version) {
+ fbb_.AddElement<int32_t>(Model::VT_VERSION, version, 0);
+ }
+ void add_selection_feature_options(flatbuffers::Offset<FeatureProcessorOptions> selection_feature_options) {
+ fbb_.AddOffset(Model::VT_SELECTION_FEATURE_OPTIONS, selection_feature_options);
+ }
+ void add_classification_feature_options(flatbuffers::Offset<FeatureProcessorOptions> classification_feature_options) {
+ fbb_.AddOffset(Model::VT_CLASSIFICATION_FEATURE_OPTIONS, classification_feature_options);
+ }
+ void add_selection_model(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> selection_model) {
+ fbb_.AddOffset(Model::VT_SELECTION_MODEL, selection_model);
+ }
+ void add_classification_model(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> classification_model) {
+ fbb_.AddOffset(Model::VT_CLASSIFICATION_MODEL, classification_model);
+ }
+ void add_embedding_model(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> embedding_model) {
+ fbb_.AddOffset(Model::VT_EMBEDDING_MODEL, embedding_model);
+ }
+ void add_regex_options(flatbuffers::Offset<RegexModelOptions> regex_options) {
+ fbb_.AddOffset(Model::VT_REGEX_OPTIONS, regex_options);
+ }
+ void add_selection_options(flatbuffers::Offset<SelectionModelOptions> selection_options) {
+ fbb_.AddOffset(Model::VT_SELECTION_OPTIONS, selection_options);
+ }
+ void add_classification_options(flatbuffers::Offset<ClassificationModelOptions> classification_options) {
+ fbb_.AddOffset(Model::VT_CLASSIFICATION_OPTIONS, classification_options);
+ }
+ void add_regex_model(flatbuffers::Offset<StructuredRegexModel> regex_model) {
+ fbb_.AddOffset(Model::VT_REGEX_MODEL, regex_model);
+ }
+ explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ModelBuilder &operator=(const ModelBuilder &);
+ flatbuffers::Offset<Model> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Model>(end);
+ return o;
+ }
+};
+
+// Builds a Model table from already-serialized offsets. Fields are added
+// from highest to lowest vtable offset, as flatc emits them.
+inline flatbuffers::Offset<Model> CreateModel(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> language = 0,
+ int32_t version = 0,
+ flatbuffers::Offset<FeatureProcessorOptions> selection_feature_options = 0,
+ flatbuffers::Offset<FeatureProcessorOptions> classification_feature_options = 0,
+ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> selection_model = 0,
+ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> classification_model = 0,
+ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> embedding_model = 0,
+ flatbuffers::Offset<RegexModelOptions> regex_options = 0,
+ flatbuffers::Offset<SelectionModelOptions> selection_options = 0,
+ flatbuffers::Offset<ClassificationModelOptions> classification_options = 0,
+ flatbuffers::Offset<StructuredRegexModel> regex_model = 0) {
+ ModelBuilder builder_(_fbb);
+ builder_.add_regex_model(regex_model);
+ builder_.add_classification_options(classification_options);
+ builder_.add_selection_options(selection_options);
+ builder_.add_regex_options(regex_options);
+ builder_.add_embedding_model(embedding_model);
+ builder_.add_classification_model(classification_model);
+ builder_.add_selection_model(selection_model);
+ builder_.add_classification_feature_options(classification_feature_options);
+ builder_.add_selection_feature_options(selection_feature_options);
+ builder_.add_version(version);
+ builder_.add_language(language);
+ return builder_.Finish();
+}
+
+// Convenience overload: serializes the string and byte vectors first, then
+// delegates to CreateModel.
+inline flatbuffers::Offset<Model> CreateModelDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *language = nullptr,
+ int32_t version = 0,
+ flatbuffers::Offset<FeatureProcessorOptions> selection_feature_options = 0,
+ flatbuffers::Offset<FeatureProcessorOptions> classification_feature_options = 0,
+ const std::vector<uint8_t> *selection_model = nullptr,
+ const std::vector<uint8_t> *classification_model = nullptr,
+ const std::vector<uint8_t> *embedding_model = nullptr,
+ flatbuffers::Offset<RegexModelOptions> regex_options = 0,
+ flatbuffers::Offset<SelectionModelOptions> selection_options = 0,
+ flatbuffers::Offset<ClassificationModelOptions> classification_options = 0,
+ flatbuffers::Offset<StructuredRegexModel> regex_model = 0) {
+ return libtextclassifier2::CreateModel(
+ _fbb,
+ language ? _fbb.CreateString(language) : 0,
+ version,
+ selection_feature_options,
+ classification_feature_options,
+ selection_model ? _fbb.CreateVector<uint8_t>(*selection_model) : 0,
+ classification_model ? _fbb.CreateVector<uint8_t>(*classification_model) : 0,
+ embedding_model ? _fbb.CreateVector<uint8_t>(*embedding_model) : 0,
+ regex_options,
+ selection_options,
+ classification_options,
+ regex_model);
+}
+
+// Object-API pack overload; implementation not in this chunk.
+flatbuffers::Offset<Model> CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+// Native (object API) mirror of a TokenizationCodepointRange: a codepoint
+// interval with its tokenizer role and script id.
+struct TokenizationCodepointRangeT : public flatbuffers::NativeTable {
+ typedef TokenizationCodepointRange TableType;
+ int32_t start;
+ int32_t end;
+ libtextclassifier2::TokenizationCodepointRange_::Role role;
+ int32_t script_id;
+ TokenizationCodepointRangeT()
+ : start(0),
+ end(0),
+ role(libtextclassifier2::TokenizationCodepointRange_::Role_DEFAULT_ROLE),
+ script_id(0) {
+ }
+};
+
+// Zero-copy accessor view over a serialized TokenizationCodepointRange.
+struct TokenizationCodepointRange FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef TokenizationCodepointRangeT NativeTableType;
+ enum {
+ VT_START = 4,
+ VT_END = 6,
+ VT_ROLE = 8,
+ VT_SCRIPT_ID = 10
+ };
+ int32_t start() const {
+ return GetField<int32_t>(VT_START, 0);
+ }
+ int32_t end() const {
+ return GetField<int32_t>(VT_END, 0);
+ }
+ // Role is stored as an int32 scalar and cast back to the enum on read.
+ libtextclassifier2::TokenizationCodepointRange_::Role role() const {
+ return static_cast<libtextclassifier2::TokenizationCodepointRange_::Role>(GetField<int32_t>(VT_ROLE, 0));
+ }
+ int32_t script_id() const {
+ return GetField<int32_t>(VT_SCRIPT_ID, 0);
+ }
+ // Bounds-checks all four scalar fields.
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_START) &&
+ VerifyField<int32_t>(verifier, VT_END) &&
+ VerifyField<int32_t>(verifier, VT_ROLE) &&
+ VerifyField<int32_t>(verifier, VT_SCRIPT_ID) &&
+ verifier.EndTable();
+ }
+ TokenizationCodepointRangeT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(TokenizationCodepointRangeT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<TokenizationCodepointRange> Pack(flatbuffers::FlatBufferBuilder &_fbb, const TokenizationCodepointRangeT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+// Incremental builder for a TokenizationCodepointRange table.
+struct TokenizationCodepointRangeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_start(int32_t start) {
+ fbb_.AddElement<int32_t>(TokenizationCodepointRange::VT_START, start, 0);
+ }
+ void add_end(int32_t end) {
+ fbb_.AddElement<int32_t>(TokenizationCodepointRange::VT_END, end, 0);
+ }
+ void add_role(libtextclassifier2::TokenizationCodepointRange_::Role role) {
+ fbb_.AddElement<int32_t>(TokenizationCodepointRange::VT_ROLE, static_cast<int32_t>(role), 0);
+ }
+ void add_script_id(int32_t script_id) {
+ fbb_.AddElement<int32_t>(TokenizationCodepointRange::VT_SCRIPT_ID, script_id, 0);
+ }
+ explicit TokenizationCodepointRangeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TokenizationCodepointRangeBuilder &operator=(const TokenizationCodepointRangeBuilder &);
+ flatbuffers::Offset<TokenizationCodepointRange> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TokenizationCodepointRange>(end);
+ return o;
+ }
+};
+
+// Builds a TokenizationCodepointRange from explicit field values.
+inline flatbuffers::Offset<TokenizationCodepointRange> CreateTokenizationCodepointRange(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t start = 0,
+ int32_t end = 0,
+ libtextclassifier2::TokenizationCodepointRange_::Role role = libtextclassifier2::TokenizationCodepointRange_::Role_DEFAULT_ROLE,
+ int32_t script_id = 0) {
+ TokenizationCodepointRangeBuilder builder_(_fbb);
+ builder_.add_script_id(script_id);
+ builder_.add_role(role);
+ builder_.add_end(end);
+ builder_.add_start(start);
+ return builder_.Finish();
+}
+
+// Object-API pack overload; implementation not in this chunk.
+flatbuffers::Offset<TokenizationCodepointRange> CreateTokenizationCodepointRange(flatbuffers::FlatBufferBuilder &_fbb, const TokenizationCodepointRangeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+// Native (object API) mirror of the FeatureProcessorOptions table, holding
+// every tokenization/feature-extraction knob as owned C++ values. Defaults
+// in the constructor match the schema defaults (several -1 sentinels mean
+// "unset"; see the accessor struct's GetField defaults below).
+struct FeatureProcessorOptionsT : public flatbuffers::NativeTable {
+ typedef FeatureProcessorOptions TableType;
+ int32_t num_buckets;
+ int32_t embedding_size;
+ int32_t context_size;
+ int32_t max_selection_span;
+ std::vector<int32_t> chargram_orders;
+ int32_t max_word_length;
+ bool unicode_aware_features;
+ bool extract_case_feature;
+ bool extract_selection_mask_feature;
+ std::vector<std::string> regexp_feature;
+ bool remap_digits;
+ bool lowercase_tokens;
+ bool selection_reduced_output_space;
+ std::vector<std::string> collections;
+ int32_t default_collection;
+ bool only_use_line_with_click;
+ bool split_tokens_on_selection_boundaries;
+ std::vector<std::unique_ptr<TokenizationCodepointRangeT>> tokenization_codepoint_config;
+ libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method;
+ bool snap_label_span_boundaries_to_containing_tokens;
+ std::vector<std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>> supported_codepoint_ranges;
+ std::vector<std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>> internal_tokenizer_codepoint_ranges;
+ float min_supported_codepoint_ratio;
+ int32_t feature_version;
+ libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type;
+ bool icu_preserve_whitespace_tokens;
+ std::vector<int32_t> ignored_span_boundary_codepoints;
+ bool click_random_token_in_selection;
+ std::vector<std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntryT>> alternative_collection_map;
+ std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeaturesT> bounds_sensitive_features;
+ bool split_selection_candidates;
+ std::vector<std::string> allowed_chargrams;
+ bool tokenize_on_script_change;
+ FeatureProcessorOptionsT()
+ : num_buckets(-1),
+ embedding_size(-1),
+ context_size(-1),
+ max_selection_span(-1),
+ max_word_length(20),
+ unicode_aware_features(false),
+ extract_case_feature(false),
+ extract_selection_mask_feature(false),
+ remap_digits(false),
+ lowercase_tokens(false),
+ selection_reduced_output_space(false),
+ default_collection(-1),
+ only_use_line_with_click(false),
+ split_tokens_on_selection_boundaries(false),
+ center_token_selection_method(libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD),
+ snap_label_span_boundaries_to_containing_tokens(false),
+ min_supported_codepoint_ratio(0.0f),
+ feature_version(0),
+ tokenization_type(libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INVALID_TOKENIZATION_TYPE),
+ icu_preserve_whitespace_tokens(false),
+ click_random_token_in_selection(false),
+ split_selection_candidates(false),
+ tokenize_on_script_change(false) {
+ }
+};
+
+struct FeatureProcessorOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FeatureProcessorOptionsT NativeTableType;
+ enum {
+ VT_NUM_BUCKETS = 4,
+ VT_EMBEDDING_SIZE = 6,
+ VT_CONTEXT_SIZE = 8,
+ VT_MAX_SELECTION_SPAN = 10,
+ VT_CHARGRAM_ORDERS = 12,
+ VT_MAX_WORD_LENGTH = 14,
+ VT_UNICODE_AWARE_FEATURES = 16,
+ VT_EXTRACT_CASE_FEATURE = 18,
+ VT_EXTRACT_SELECTION_MASK_FEATURE = 20,
+ VT_REGEXP_FEATURE = 22,
+ VT_REMAP_DIGITS = 24,
+ VT_LOWERCASE_TOKENS = 26,
+ VT_SELECTION_REDUCED_OUTPUT_SPACE = 28,
+ VT_COLLECTIONS = 30,
+ VT_DEFAULT_COLLECTION = 32,
+ VT_ONLY_USE_LINE_WITH_CLICK = 34,
+ VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES = 36,
+ VT_TOKENIZATION_CODEPOINT_CONFIG = 38,
+ VT_CENTER_TOKEN_SELECTION_METHOD = 40,
+ VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS = 42,
+ VT_SUPPORTED_CODEPOINT_RANGES = 44,
+ VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES = 46,
+ VT_MIN_SUPPORTED_CODEPOINT_RATIO = 48,
+ VT_FEATURE_VERSION = 50,
+ VT_TOKENIZATION_TYPE = 52,
+ VT_ICU_PRESERVE_WHITESPACE_TOKENS = 54,
+ VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS = 56,
+ VT_CLICK_RANDOM_TOKEN_IN_SELECTION = 58,
+ VT_ALTERNATIVE_COLLECTION_MAP = 60,
+ VT_BOUNDS_SENSITIVE_FEATURES = 62,
+ VT_SPLIT_SELECTION_CANDIDATES = 64,
+ VT_ALLOWED_CHARGRAMS = 66,
+ VT_TOKENIZE_ON_SCRIPT_CHANGE = 68
+ };
+ int32_t num_buckets() const {
+ return GetField<int32_t>(VT_NUM_BUCKETS, -1);
+ }
+ int32_t embedding_size() const {
+ return GetField<int32_t>(VT_EMBEDDING_SIZE, -1);
+ }
+ int32_t context_size() const {
+ return GetField<int32_t>(VT_CONTEXT_SIZE, -1);
+ }
+ int32_t max_selection_span() const {
+ return GetField<int32_t>(VT_MAX_SELECTION_SPAN, -1);
+ }
+ const flatbuffers::Vector<int32_t> *chargram_orders() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_CHARGRAM_ORDERS);
+ }
+ int32_t max_word_length() const {
+ return GetField<int32_t>(VT_MAX_WORD_LENGTH, 20);
+ }
+ bool unicode_aware_features() const {
+ return GetField<uint8_t>(VT_UNICODE_AWARE_FEATURES, 0) != 0;
+ }
+ bool extract_case_feature() const {
+ return GetField<uint8_t>(VT_EXTRACT_CASE_FEATURE, 0) != 0;
+ }
+ bool extract_selection_mask_feature() const {
+ return GetField<uint8_t>(VT_EXTRACT_SELECTION_MASK_FEATURE, 0) != 0;
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *regexp_feature() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_REGEXP_FEATURE);
+ }
+ bool remap_digits() const {
+ return GetField<uint8_t>(VT_REMAP_DIGITS, 0) != 0;
+ }
+ bool lowercase_tokens() const {
+ return GetField<uint8_t>(VT_LOWERCASE_TOKENS, 0) != 0;
+ }
+ bool selection_reduced_output_space() const {
+ return GetField<uint8_t>(VT_SELECTION_REDUCED_OUTPUT_SPACE, 0) != 0;
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *collections() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_COLLECTIONS);
+ }
+ int32_t default_collection() const {
+ return GetField<int32_t>(VT_DEFAULT_COLLECTION, -1);
+ }
+ bool only_use_line_with_click() const {
+ return GetField<uint8_t>(VT_ONLY_USE_LINE_WITH_CLICK, 0) != 0;
+ }
+ bool split_tokens_on_selection_boundaries() const {
+ return GetField<uint8_t>(VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES, 0) != 0;
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>> *tokenization_codepoint_config() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>> *>(VT_TOKENIZATION_CODEPOINT_CONFIG);
+ }
+ libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method() const {
+ return static_cast<libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod>(GetField<int32_t>(VT_CENTER_TOKEN_SELECTION_METHOD, 0));
+ }
+ bool snap_label_span_boundaries_to_containing_tokens() const {
+ return GetField<uint8_t>(VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS, 0) != 0;
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *supported_codepoint_ranges() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *>(VT_SUPPORTED_CODEPOINT_RANGES);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *internal_tokenizer_codepoint_ranges() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *>(VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES);
+ }
+ float min_supported_codepoint_ratio() const {
+ return GetField<float>(VT_MIN_SUPPORTED_CODEPOINT_RATIO, 0.0f);
+ }
+ int32_t feature_version() const {
+ return GetField<int32_t>(VT_FEATURE_VERSION, 0);
+ }
+ libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type() const {
+ return static_cast<libtextclassifier2::FeatureProcessorOptions_::TokenizationType>(GetField<int32_t>(VT_TOKENIZATION_TYPE, 0));
+ }
+ bool icu_preserve_whitespace_tokens() const {
+ return GetField<uint8_t>(VT_ICU_PRESERVE_WHITESPACE_TOKENS, 0) != 0;
+ }
+ const flatbuffers::Vector<int32_t> *ignored_span_boundary_codepoints() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS);
+ }
+ bool click_random_token_in_selection() const {
+ return GetField<uint8_t>(VT_CLICK_RANDOM_TOKEN_IN_SELECTION, 0) != 0;
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntry>> *alternative_collection_map() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntry>> *>(VT_ALTERNATIVE_COLLECTION_MAP);
+ }
+ const libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures *bounds_sensitive_features() const {
+ return GetPointer<const libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures *>(VT_BOUNDS_SENSITIVE_FEATURES);
+ }
+ bool split_selection_candidates() const {
+ return GetField<uint8_t>(VT_SPLIT_SELECTION_CANDIDATES, 0) != 0;
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *allowed_chargrams() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_ALLOWED_CHARGRAMS);
+ }
+ bool tokenize_on_script_change() const {
+ return GetField<uint8_t>(VT_TOKENIZE_ON_SCRIPT_CHANGE, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_NUM_BUCKETS) &&
+ VerifyField<int32_t>(verifier, VT_EMBEDDING_SIZE) &&
+ VerifyField<int32_t>(verifier, VT_CONTEXT_SIZE) &&
+ VerifyField<int32_t>(verifier, VT_MAX_SELECTION_SPAN) &&
+ VerifyOffset(verifier, VT_CHARGRAM_ORDERS) &&
+ verifier.Verify(chargram_orders()) &&
+ VerifyField<int32_t>(verifier, VT_MAX_WORD_LENGTH) &&
+ VerifyField<uint8_t>(verifier, VT_UNICODE_AWARE_FEATURES) &&
+ VerifyField<uint8_t>(verifier, VT_EXTRACT_CASE_FEATURE) &&
+ VerifyField<uint8_t>(verifier, VT_EXTRACT_SELECTION_MASK_FEATURE) &&
+ VerifyOffset(verifier, VT_REGEXP_FEATURE) &&
+ verifier.Verify(regexp_feature()) &&
+ verifier.VerifyVectorOfStrings(regexp_feature()) &&
+ VerifyField<uint8_t>(verifier, VT_REMAP_DIGITS) &&
+ VerifyField<uint8_t>(verifier, VT_LOWERCASE_TOKENS) &&
+ VerifyField<uint8_t>(verifier, VT_SELECTION_REDUCED_OUTPUT_SPACE) &&
+ VerifyOffset(verifier, VT_COLLECTIONS) &&
+ verifier.Verify(collections()) &&
+ verifier.VerifyVectorOfStrings(collections()) &&
+ VerifyField<int32_t>(verifier, VT_DEFAULT_COLLECTION) &&
+ VerifyField<uint8_t>(verifier, VT_ONLY_USE_LINE_WITH_CLICK) &&
+ VerifyField<uint8_t>(verifier, VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES) &&
+ VerifyOffset(verifier, VT_TOKENIZATION_CODEPOINT_CONFIG) &&
+ verifier.Verify(tokenization_codepoint_config()) &&
+ verifier.VerifyVectorOfTables(tokenization_codepoint_config()) &&
+ VerifyField<int32_t>(verifier, VT_CENTER_TOKEN_SELECTION_METHOD) &&
+ VerifyField<uint8_t>(verifier, VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS) &&
+ VerifyOffset(verifier, VT_SUPPORTED_CODEPOINT_RANGES) &&
+ verifier.Verify(supported_codepoint_ranges()) &&
+ verifier.VerifyVectorOfTables(supported_codepoint_ranges()) &&
+ VerifyOffset(verifier, VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES) &&
+ verifier.Verify(internal_tokenizer_codepoint_ranges()) &&
+ verifier.VerifyVectorOfTables(internal_tokenizer_codepoint_ranges()) &&
+ VerifyField<float>(verifier, VT_MIN_SUPPORTED_CODEPOINT_RATIO) &&
+ VerifyField<int32_t>(verifier, VT_FEATURE_VERSION) &&
+ VerifyField<int32_t>(verifier, VT_TOKENIZATION_TYPE) &&
+ VerifyField<uint8_t>(verifier, VT_ICU_PRESERVE_WHITESPACE_TOKENS) &&
+ VerifyOffset(verifier, VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS) &&
+ verifier.Verify(ignored_span_boundary_codepoints()) &&
+ VerifyField<uint8_t>(verifier, VT_CLICK_RANDOM_TOKEN_IN_SELECTION) &&
+ VerifyOffset(verifier, VT_ALTERNATIVE_COLLECTION_MAP) &&
+ verifier.Verify(alternative_collection_map()) &&
+ verifier.VerifyVectorOfTables(alternative_collection_map()) &&
+ VerifyOffset(verifier, VT_BOUNDS_SENSITIVE_FEATURES) &&
+ verifier.VerifyTable(bounds_sensitive_features()) &&
+ VerifyField<uint8_t>(verifier, VT_SPLIT_SELECTION_CANDIDATES) &&
+ VerifyOffset(verifier, VT_ALLOWED_CHARGRAMS) &&
+ verifier.Verify(allowed_chargrams()) &&
+ verifier.VerifyVectorOfStrings(allowed_chargrams()) &&
+ VerifyField<uint8_t>(verifier, VT_TOKENIZE_ON_SCRIPT_CHANGE) &&
+ verifier.EndTable();
+ }
+ FeatureProcessorOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(FeatureProcessorOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<FeatureProcessorOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct FeatureProcessorOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_num_buckets(int32_t num_buckets) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_NUM_BUCKETS, num_buckets, -1);
+ }
+ void add_embedding_size(int32_t embedding_size) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_EMBEDDING_SIZE, embedding_size, -1);
+ }
+ void add_context_size(int32_t context_size) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_CONTEXT_SIZE, context_size, -1);
+ }
+ void add_max_selection_span(int32_t max_selection_span) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_MAX_SELECTION_SPAN, max_selection_span, -1);
+ }
+ void add_chargram_orders(flatbuffers::Offset<flatbuffers::Vector<int32_t>> chargram_orders) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_CHARGRAM_ORDERS, chargram_orders);
+ }
+ void add_max_word_length(int32_t max_word_length) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_MAX_WORD_LENGTH, max_word_length, 20);
+ }
+ void add_unicode_aware_features(bool unicode_aware_features) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_UNICODE_AWARE_FEATURES, static_cast<uint8_t>(unicode_aware_features), 0);
+ }
+ void add_extract_case_feature(bool extract_case_feature) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_EXTRACT_CASE_FEATURE, static_cast<uint8_t>(extract_case_feature), 0);
+ }
+ void add_extract_selection_mask_feature(bool extract_selection_mask_feature) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_EXTRACT_SELECTION_MASK_FEATURE, static_cast<uint8_t>(extract_selection_mask_feature), 0);
+ }
+ void add_regexp_feature(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> regexp_feature) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_REGEXP_FEATURE, regexp_feature);
+ }
+ void add_remap_digits(bool remap_digits) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_REMAP_DIGITS, static_cast<uint8_t>(remap_digits), 0);
+ }
+ void add_lowercase_tokens(bool lowercase_tokens) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_LOWERCASE_TOKENS, static_cast<uint8_t>(lowercase_tokens), 0);
+ }
+ void add_selection_reduced_output_space(bool selection_reduced_output_space) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SELECTION_REDUCED_OUTPUT_SPACE, static_cast<uint8_t>(selection_reduced_output_space), 0);
+ }
+ void add_collections(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> collections) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_COLLECTIONS, collections);
+ }
+ void add_default_collection(int32_t default_collection) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_DEFAULT_COLLECTION, default_collection, -1);
+ }
+ void add_only_use_line_with_click(bool only_use_line_with_click) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_ONLY_USE_LINE_WITH_CLICK, static_cast<uint8_t>(only_use_line_with_click), 0);
+ }
+ void add_split_tokens_on_selection_boundaries(bool split_tokens_on_selection_boundaries) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SPLIT_TOKENS_ON_SELECTION_BOUNDARIES, static_cast<uint8_t>(split_tokens_on_selection_boundaries), 0);
+ }
+ void add_tokenization_codepoint_config(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>>> tokenization_codepoint_config) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_TOKENIZATION_CODEPOINT_CONFIG, tokenization_codepoint_config);
+ }
+ void add_center_token_selection_method(libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_CENTER_TOKEN_SELECTION_METHOD, static_cast<int32_t>(center_token_selection_method), 0);
+ }
+ void add_snap_label_span_boundaries_to_containing_tokens(bool snap_label_span_boundaries_to_containing_tokens) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SNAP_LABEL_SPAN_BOUNDARIES_TO_CONTAINING_TOKENS, static_cast<uint8_t>(snap_label_span_boundaries_to_containing_tokens), 0);
+ }
+ void add_supported_codepoint_ranges(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> supported_codepoint_ranges) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_SUPPORTED_CODEPOINT_RANGES, supported_codepoint_ranges);
+ }
+ void add_internal_tokenizer_codepoint_ranges(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> internal_tokenizer_codepoint_ranges) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_INTERNAL_TOKENIZER_CODEPOINT_RANGES, internal_tokenizer_codepoint_ranges);
+ }
+ void add_min_supported_codepoint_ratio(float min_supported_codepoint_ratio) {
+ fbb_.AddElement<float>(FeatureProcessorOptions::VT_MIN_SUPPORTED_CODEPOINT_RATIO, min_supported_codepoint_ratio, 0.0f);
+ }
+ void add_feature_version(int32_t feature_version) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_FEATURE_VERSION, feature_version, 0);
+ }
+ void add_tokenization_type(libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type) {
+ fbb_.AddElement<int32_t>(FeatureProcessorOptions::VT_TOKENIZATION_TYPE, static_cast<int32_t>(tokenization_type), 0);
+ }
+ void add_icu_preserve_whitespace_tokens(bool icu_preserve_whitespace_tokens) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_ICU_PRESERVE_WHITESPACE_TOKENS, static_cast<uint8_t>(icu_preserve_whitespace_tokens), 0);
+ }
+ void add_ignored_span_boundary_codepoints(flatbuffers::Offset<flatbuffers::Vector<int32_t>> ignored_span_boundary_codepoints) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_IGNORED_SPAN_BOUNDARY_CODEPOINTS, ignored_span_boundary_codepoints);
+ }
+ void add_click_random_token_in_selection(bool click_random_token_in_selection) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_CLICK_RANDOM_TOKEN_IN_SELECTION, static_cast<uint8_t>(click_random_token_in_selection), 0);
+ }
+ void add_alternative_collection_map(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntry>>> alternative_collection_map) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_ALTERNATIVE_COLLECTION_MAP, alternative_collection_map);
+ }
+ void add_bounds_sensitive_features(flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_BOUNDS_SENSITIVE_FEATURES, bounds_sensitive_features);
+ }
+ void add_split_selection_candidates(bool split_selection_candidates) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_SPLIT_SELECTION_CANDIDATES, static_cast<uint8_t>(split_selection_candidates), 0);
+ }
+ void add_allowed_chargrams(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> allowed_chargrams) {
+ fbb_.AddOffset(FeatureProcessorOptions::VT_ALLOWED_CHARGRAMS, allowed_chargrams);
+ }
+ void add_tokenize_on_script_change(bool tokenize_on_script_change) {
+ fbb_.AddElement<uint8_t>(FeatureProcessorOptions::VT_TOKENIZE_ON_SCRIPT_CHANGE, static_cast<uint8_t>(tokenize_on_script_change), 0);
+ }
+ explicit FeatureProcessorOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FeatureProcessorOptionsBuilder &operator=(const FeatureProcessorOptionsBuilder &);
+ flatbuffers::Offset<FeatureProcessorOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<FeatureProcessorOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t num_buckets = -1,
+ int32_t embedding_size = -1,
+ int32_t context_size = -1,
+ int32_t max_selection_span = -1,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> chargram_orders = 0,
+ int32_t max_word_length = 20,
+ bool unicode_aware_features = false,
+ bool extract_case_feature = false,
+ bool extract_selection_mask_feature = false,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> regexp_feature = 0,
+ bool remap_digits = false,
+ bool lowercase_tokens = false,
+ bool selection_reduced_output_space = false,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> collections = 0,
+ int32_t default_collection = -1,
+ bool only_use_line_with_click = false,
+ bool split_tokens_on_selection_boundaries = false,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TokenizationCodepointRange>>> tokenization_codepoint_config = 0,
+ libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method = libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD,
+ bool snap_label_span_boundaries_to_containing_tokens = false,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> supported_codepoint_ranges = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>> internal_tokenizer_codepoint_ranges = 0,
+ float min_supported_codepoint_ratio = 0.0f,
+ int32_t feature_version = 0,
+ libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type = libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INVALID_TOKENIZATION_TYPE,
+ bool icu_preserve_whitespace_tokens = false,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> ignored_span_boundary_codepoints = 0,
+ bool click_random_token_in_selection = false,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntry>>> alternative_collection_map = 0,
+ flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features = 0,
+ bool split_selection_candidates = false,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> allowed_chargrams = 0,
+ bool tokenize_on_script_change = false) {
+ FeatureProcessorOptionsBuilder builder_(_fbb);
+ builder_.add_allowed_chargrams(allowed_chargrams);
+ builder_.add_bounds_sensitive_features(bounds_sensitive_features);
+ builder_.add_alternative_collection_map(alternative_collection_map);
+ builder_.add_ignored_span_boundary_codepoints(ignored_span_boundary_codepoints);
+ builder_.add_tokenization_type(tokenization_type);
+ builder_.add_feature_version(feature_version);
+ builder_.add_min_supported_codepoint_ratio(min_supported_codepoint_ratio);
+ builder_.add_internal_tokenizer_codepoint_ranges(internal_tokenizer_codepoint_ranges);
+ builder_.add_supported_codepoint_ranges(supported_codepoint_ranges);
+ builder_.add_center_token_selection_method(center_token_selection_method);
+ builder_.add_tokenization_codepoint_config(tokenization_codepoint_config);
+ builder_.add_default_collection(default_collection);
+ builder_.add_collections(collections);
+ builder_.add_regexp_feature(regexp_feature);
+ builder_.add_max_word_length(max_word_length);
+ builder_.add_chargram_orders(chargram_orders);
+ builder_.add_max_selection_span(max_selection_span);
+ builder_.add_context_size(context_size);
+ builder_.add_embedding_size(embedding_size);
+ builder_.add_num_buckets(num_buckets);
+ builder_.add_tokenize_on_script_change(tokenize_on_script_change);
+ builder_.add_split_selection_candidates(split_selection_candidates);
+ builder_.add_click_random_token_in_selection(click_random_token_in_selection);
+ builder_.add_icu_preserve_whitespace_tokens(icu_preserve_whitespace_tokens);
+ builder_.add_snap_label_span_boundaries_to_containing_tokens(snap_label_span_boundaries_to_containing_tokens);
+ builder_.add_split_tokens_on_selection_boundaries(split_tokens_on_selection_boundaries);
+ builder_.add_only_use_line_with_click(only_use_line_with_click);
+ builder_.add_selection_reduced_output_space(selection_reduced_output_space);
+ builder_.add_lowercase_tokens(lowercase_tokens);
+ builder_.add_remap_digits(remap_digits);
+ builder_.add_extract_selection_mask_feature(extract_selection_mask_feature);
+ builder_.add_extract_case_feature(extract_case_feature);
+ builder_.add_unicode_aware_features(unicode_aware_features);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptionsDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t num_buckets = -1,
+ int32_t embedding_size = -1,
+ int32_t context_size = -1,
+ int32_t max_selection_span = -1,
+ const std::vector<int32_t> *chargram_orders = nullptr,
+ int32_t max_word_length = 20,
+ bool unicode_aware_features = false,
+ bool extract_case_feature = false,
+ bool extract_selection_mask_feature = false,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *regexp_feature = nullptr,
+ bool remap_digits = false,
+ bool lowercase_tokens = false,
+ bool selection_reduced_output_space = false,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *collections = nullptr,
+ int32_t default_collection = -1,
+ bool only_use_line_with_click = false,
+ bool split_tokens_on_selection_boundaries = false,
+ const std::vector<flatbuffers::Offset<TokenizationCodepointRange>> *tokenization_codepoint_config = nullptr,
+ libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod center_token_selection_method = libtextclassifier2::FeatureProcessorOptions_::CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD,
+ bool snap_label_span_boundaries_to_containing_tokens = false,
+ const std::vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *supported_codepoint_ranges = nullptr,
+ const std::vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> *internal_tokenizer_codepoint_ranges = nullptr,
+ float min_supported_codepoint_ratio = 0.0f,
+ int32_t feature_version = 0,
+ libtextclassifier2::FeatureProcessorOptions_::TokenizationType tokenization_type = libtextclassifier2::FeatureProcessorOptions_::TokenizationType_INVALID_TOKENIZATION_TYPE,
+ bool icu_preserve_whitespace_tokens = false,
+ const std::vector<int32_t> *ignored_span_boundary_codepoints = nullptr,
+ bool click_random_token_in_selection = false,
+ const std::vector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntry>> *alternative_collection_map = nullptr,
+ flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeatures> bounds_sensitive_features = 0,
+ bool split_selection_candidates = false,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *allowed_chargrams = nullptr,
+ bool tokenize_on_script_change = false) {
+ return libtextclassifier2::CreateFeatureProcessorOptions(
+ _fbb,
+ num_buckets,
+ embedding_size,
+ context_size,
+ max_selection_span,
+ chargram_orders ? _fbb.CreateVector<int32_t>(*chargram_orders) : 0,
+ max_word_length,
+ unicode_aware_features,
+ extract_case_feature,
+ extract_selection_mask_feature,
+ regexp_feature ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*regexp_feature) : 0,
+ remap_digits,
+ lowercase_tokens,
+ selection_reduced_output_space,
+ collections ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*collections) : 0,
+ default_collection,
+ only_use_line_with_click,
+ split_tokens_on_selection_boundaries,
+ tokenization_codepoint_config ? _fbb.CreateVector<flatbuffers::Offset<TokenizationCodepointRange>>(*tokenization_codepoint_config) : 0,
+ center_token_selection_method,
+ snap_label_span_boundaries_to_containing_tokens,
+ supported_codepoint_ranges ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>(*supported_codepoint_ranges) : 0,
+ internal_tokenizer_codepoint_ranges ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>>(*internal_tokenizer_codepoint_ranges) : 0,
+ min_supported_codepoint_ratio,
+ feature_version,
+ tokenization_type,
+ icu_preserve_whitespace_tokens,
+ ignored_span_boundary_codepoints ? _fbb.CreateVector<int32_t>(*ignored_span_boundary_codepoints) : 0,
+ click_random_token_in_selection,
+ alternative_collection_map ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntry>>(*alternative_collection_map) : 0,
+ bounds_sensitive_features,
+ split_selection_candidates,
+ allowed_chargrams ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*allowed_chargrams) : 0,
+ tokenize_on_script_change);
+}
+
+flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+namespace FeatureProcessorOptions_ {
+
+struct CodepointRangeT : public flatbuffers::NativeTable {
+ typedef CodepointRange TableType;
+ int32_t start;
+ int32_t end;
+ CodepointRangeT()
+ : start(0),
+ end(0) {
+ }
+};
+
+struct CodepointRange FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef CodepointRangeT NativeTableType;
+ enum {
+ VT_START = 4,
+ VT_END = 6
+ };
+ int32_t start() const {
+ return GetField<int32_t>(VT_START, 0);
+ }
+ int32_t end() const {
+ return GetField<int32_t>(VT_END, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_START) &&
+ VerifyField<int32_t>(verifier, VT_END) &&
+ verifier.EndTable();
+ }
+ CodepointRangeT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(CodepointRangeT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<CodepointRange> Pack(flatbuffers::FlatBufferBuilder &_fbb, const CodepointRangeT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct CodepointRangeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_start(int32_t start) {
+ fbb_.AddElement<int32_t>(CodepointRange::VT_START, start, 0);
+ }
+ void add_end(int32_t end) {
+ fbb_.AddElement<int32_t>(CodepointRange::VT_END, end, 0);
+ }
+ explicit CodepointRangeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ CodepointRangeBuilder &operator=(const CodepointRangeBuilder &);
+ flatbuffers::Offset<CodepointRange> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<CodepointRange>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<CodepointRange> CreateCodepointRange(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t start = 0,
+ int32_t end = 0) {
+ CodepointRangeBuilder builder_(_fbb);
+ builder_.add_end(end);
+ builder_.add_start(start);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<CodepointRange> CreateCodepointRange(flatbuffers::FlatBufferBuilder &_fbb, const CodepointRangeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct CollectionMapEntryT : public flatbuffers::NativeTable {
+ typedef CollectionMapEntry TableType;
+ std::string key;
+ std::string value;
+ CollectionMapEntryT() {
+ }
+};
+
+struct CollectionMapEntry FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef CollectionMapEntryT NativeTableType;
+ enum {
+ VT_KEY = 4,
+ VT_VALUE = 6
+ };
+ const flatbuffers::String *key() const {
+ return GetPointer<const flatbuffers::String *>(VT_KEY);
+ }
+ const flatbuffers::String *value() const {
+ return GetPointer<const flatbuffers::String *>(VT_VALUE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_KEY) &&
+ verifier.Verify(key()) &&
+ VerifyOffset(verifier, VT_VALUE) &&
+ verifier.Verify(value()) &&
+ verifier.EndTable();
+ }
+ CollectionMapEntryT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(CollectionMapEntryT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<CollectionMapEntry> Pack(flatbuffers::FlatBufferBuilder &_fbb, const CollectionMapEntryT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct CollectionMapEntryBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_key(flatbuffers::Offset<flatbuffers::String> key) {
+ fbb_.AddOffset(CollectionMapEntry::VT_KEY, key);
+ }
+ void add_value(flatbuffers::Offset<flatbuffers::String> value) {
+ fbb_.AddOffset(CollectionMapEntry::VT_VALUE, value);
+ }
+ explicit CollectionMapEntryBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ CollectionMapEntryBuilder &operator=(const CollectionMapEntryBuilder &);
+ flatbuffers::Offset<CollectionMapEntry> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<CollectionMapEntry>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<CollectionMapEntry> CreateCollectionMapEntry(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> key = 0,
+ flatbuffers::Offset<flatbuffers::String> value = 0) {
+ CollectionMapEntryBuilder builder_(_fbb);
+ builder_.add_value(value);
+ builder_.add_key(key);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<CollectionMapEntry> CreateCollectionMapEntryDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *key = nullptr,
+ const char *value = nullptr) {
+ return libtextclassifier2::FeatureProcessorOptions_::CreateCollectionMapEntry(
+ _fbb,
+ key ? _fbb.CreateString(key) : 0,
+ value ? _fbb.CreateString(value) : 0);
+}
+
+flatbuffers::Offset<CollectionMapEntry> CreateCollectionMapEntry(flatbuffers::FlatBufferBuilder &_fbb, const CollectionMapEntryT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct BoundsSensitiveFeaturesT : public flatbuffers::NativeTable {
+ typedef BoundsSensitiveFeatures TableType;
+ bool enabled;
+ int32_t num_tokens_before;
+ int32_t num_tokens_inside_left;
+ int32_t num_tokens_inside_right;
+ int32_t num_tokens_after;
+ bool include_inside_bag;
+ bool include_inside_length;
+ BoundsSensitiveFeaturesT()
+ : enabled(false),
+ num_tokens_before(0),
+ num_tokens_inside_left(0),
+ num_tokens_inside_right(0),
+ num_tokens_after(0),
+ include_inside_bag(false),
+ include_inside_length(false) {
+ }
+};
+
+struct BoundsSensitiveFeatures FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef BoundsSensitiveFeaturesT NativeTableType;
+ enum {
+ VT_ENABLED = 4,
+ VT_NUM_TOKENS_BEFORE = 6,
+ VT_NUM_TOKENS_INSIDE_LEFT = 8,
+ VT_NUM_TOKENS_INSIDE_RIGHT = 10,
+ VT_NUM_TOKENS_AFTER = 12,
+ VT_INCLUDE_INSIDE_BAG = 14,
+ VT_INCLUDE_INSIDE_LENGTH = 16
+ };
+ bool enabled() const {
+ return GetField<uint8_t>(VT_ENABLED, 0) != 0;
+ }
+ int32_t num_tokens_before() const {
+ return GetField<int32_t>(VT_NUM_TOKENS_BEFORE, 0);
+ }
+ int32_t num_tokens_inside_left() const {
+ return GetField<int32_t>(VT_NUM_TOKENS_INSIDE_LEFT, 0);
+ }
+ int32_t num_tokens_inside_right() const {
+ return GetField<int32_t>(VT_NUM_TOKENS_INSIDE_RIGHT, 0);
+ }
+ int32_t num_tokens_after() const {
+ return GetField<int32_t>(VT_NUM_TOKENS_AFTER, 0);
+ }
+ bool include_inside_bag() const {
+ return GetField<uint8_t>(VT_INCLUDE_INSIDE_BAG, 0) != 0;
+ }
+ bool include_inside_length() const {
+ return GetField<uint8_t>(VT_INCLUDE_INSIDE_LENGTH, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_ENABLED) &&
+ VerifyField<int32_t>(verifier, VT_NUM_TOKENS_BEFORE) &&
+ VerifyField<int32_t>(verifier, VT_NUM_TOKENS_INSIDE_LEFT) &&
+ VerifyField<int32_t>(verifier, VT_NUM_TOKENS_INSIDE_RIGHT) &&
+ VerifyField<int32_t>(verifier, VT_NUM_TOKENS_AFTER) &&
+ VerifyField<uint8_t>(verifier, VT_INCLUDE_INSIDE_BAG) &&
+ VerifyField<uint8_t>(verifier, VT_INCLUDE_INSIDE_LENGTH) &&
+ verifier.EndTable();
+ }
+ BoundsSensitiveFeaturesT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(BoundsSensitiveFeaturesT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<BoundsSensitiveFeatures> Pack(flatbuffers::FlatBufferBuilder &_fbb, const BoundsSensitiveFeaturesT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct BoundsSensitiveFeaturesBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_enabled(bool enabled) {
+ fbb_.AddElement<uint8_t>(BoundsSensitiveFeatures::VT_ENABLED, static_cast<uint8_t>(enabled), 0);
+ }
+ void add_num_tokens_before(int32_t num_tokens_before) {
+ fbb_.AddElement<int32_t>(BoundsSensitiveFeatures::VT_NUM_TOKENS_BEFORE, num_tokens_before, 0);
+ }
+ void add_num_tokens_inside_left(int32_t num_tokens_inside_left) {
+ fbb_.AddElement<int32_t>(BoundsSensitiveFeatures::VT_NUM_TOKENS_INSIDE_LEFT, num_tokens_inside_left, 0);
+ }
+ void add_num_tokens_inside_right(int32_t num_tokens_inside_right) {
+ fbb_.AddElement<int32_t>(BoundsSensitiveFeatures::VT_NUM_TOKENS_INSIDE_RIGHT, num_tokens_inside_right, 0);
+ }
+ void add_num_tokens_after(int32_t num_tokens_after) {
+ fbb_.AddElement<int32_t>(BoundsSensitiveFeatures::VT_NUM_TOKENS_AFTER, num_tokens_after, 0);
+ }
+ void add_include_inside_bag(bool include_inside_bag) {
+ fbb_.AddElement<uint8_t>(BoundsSensitiveFeatures::VT_INCLUDE_INSIDE_BAG, static_cast<uint8_t>(include_inside_bag), 0);
+ }
+ void add_include_inside_length(bool include_inside_length) {
+ fbb_.AddElement<uint8_t>(BoundsSensitiveFeatures::VT_INCLUDE_INSIDE_LENGTH, static_cast<uint8_t>(include_inside_length), 0);
+ }
+ explicit BoundsSensitiveFeaturesBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ BoundsSensitiveFeaturesBuilder &operator=(const BoundsSensitiveFeaturesBuilder &);
+ flatbuffers::Offset<BoundsSensitiveFeatures> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<BoundsSensitiveFeatures>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<BoundsSensitiveFeatures> CreateBoundsSensitiveFeatures(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ bool enabled = false,
+ int32_t num_tokens_before = 0,
+ int32_t num_tokens_inside_left = 0,
+ int32_t num_tokens_inside_right = 0,
+ int32_t num_tokens_after = 0,
+ bool include_inside_bag = false,
+ bool include_inside_length = false) {
+ BoundsSensitiveFeaturesBuilder builder_(_fbb);
+ builder_.add_num_tokens_after(num_tokens_after);
+ builder_.add_num_tokens_inside_right(num_tokens_inside_right);
+ builder_.add_num_tokens_inside_left(num_tokens_inside_left);
+ builder_.add_num_tokens_before(num_tokens_before);
+ builder_.add_include_inside_length(include_inside_length);
+ builder_.add_include_inside_bag(include_inside_bag);
+ builder_.add_enabled(enabled);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<BoundsSensitiveFeatures> CreateBoundsSensitiveFeatures(flatbuffers::FlatBufferBuilder &_fbb, const BoundsSensitiveFeaturesT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+} // namespace FeatureProcessorOptions_
+
+inline SelectionModelOptionsT *SelectionModelOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new SelectionModelOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void SelectionModelOptions::UnPackTo(SelectionModelOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = strip_unpaired_brackets(); _o->strip_unpaired_brackets = _e; };
+ { auto _e = symmetry_context_size(); _o->symmetry_context_size = _e; };
+}
+
+inline flatbuffers::Offset<SelectionModelOptions> SelectionModelOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SelectionModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateSelectionModelOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<SelectionModelOptions> CreateSelectionModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const SelectionModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SelectionModelOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _strip_unpaired_brackets = _o->strip_unpaired_brackets;
+ auto _symmetry_context_size = _o->symmetry_context_size;
+ return libtextclassifier2::CreateSelectionModelOptions(
+ _fbb,
+ _strip_unpaired_brackets,
+ _symmetry_context_size);
+}
+
+inline ClassificationModelOptionsT *ClassificationModelOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new ClassificationModelOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void ClassificationModelOptions::UnPackTo(ClassificationModelOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = phone_min_num_digits(); _o->phone_min_num_digits = _e; };
+ { auto _e = phone_max_num_digits(); _o->phone_max_num_digits = _e; };
+}
+
+inline flatbuffers::Offset<ClassificationModelOptions> ClassificationModelOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ClassificationModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateClassificationModelOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ClassificationModelOptions> CreateClassificationModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const ClassificationModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ClassificationModelOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _phone_min_num_digits = _o->phone_min_num_digits;
+ auto _phone_max_num_digits = _o->phone_max_num_digits;
+ return libtextclassifier2::CreateClassificationModelOptions(
+ _fbb,
+ _phone_min_num_digits,
+ _phone_max_num_digits);
+}
+
+inline RegexModelOptionsT *RegexModelOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new RegexModelOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void RegexModelOptions::UnPackTo(RegexModelOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = patterns(); if (_e) { _o->patterns.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->patterns[_i] = std::unique_ptr<libtextclassifier2::RegexModelOptions_::PatternT>(_e->Get(_i)->UnPack(_resolver)); } } };
+}
+
+inline flatbuffers::Offset<RegexModelOptions> RegexModelOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateRegexModelOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<RegexModelOptions> CreateRegexModelOptions(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RegexModelOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _patterns = _o->patterns.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::RegexModelOptions_::Pattern>> (_o->patterns.size(), [](size_t i, _VectorArgs *__va) { return CreatePattern(*__va->__fbb, __va->__o->patterns[i].get(), __va->__rehasher); }, &_va ) : 0;
+ return libtextclassifier2::CreateRegexModelOptions(
+ _fbb,
+ _patterns);
+}
+
+namespace RegexModelOptions_ {
+
+inline PatternT *Pattern::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new PatternT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void Pattern::UnPackTo(PatternT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = collection_name(); if (_e) _o->collection_name = _e->str(); };
+ { auto _e = pattern(); if (_e) _o->pattern = _e->str(); };
+}
+
+inline flatbuffers::Offset<Pattern> Pattern::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PatternT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePattern(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<Pattern> CreatePattern(flatbuffers::FlatBufferBuilder &_fbb, const PatternT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PatternT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _collection_name = _o->collection_name.empty() ? 0 : _fbb.CreateString(_o->collection_name);
+ auto _pattern = _o->pattern.empty() ? 0 : _fbb.CreateString(_o->pattern);
+ return libtextclassifier2::RegexModelOptions_::CreatePattern(
+ _fbb,
+ _collection_name,
+ _pattern);
+}
+
+} // namespace RegexModelOptions_
+
+inline StructuredRegexModelT *StructuredRegexModel::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new StructuredRegexModelT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void StructuredRegexModel::UnPackTo(StructuredRegexModelT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = patterns(); if (_e) { _o->patterns.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->patterns[_i] = std::unique_ptr<libtextclassifier2::StructuredRegexModel_::StructuredPatternT>(_e->Get(_i)->UnPack(_resolver)); } } };
+}
+
+inline flatbuffers::Offset<StructuredRegexModel> StructuredRegexModel::Pack(flatbuffers::FlatBufferBuilder &_fbb, const StructuredRegexModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateStructuredRegexModel(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<StructuredRegexModel> CreateStructuredRegexModel(flatbuffers::FlatBufferBuilder &_fbb, const StructuredRegexModelT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const StructuredRegexModelT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _patterns = _o->patterns.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::StructuredRegexModel_::StructuredPattern>> (_o->patterns.size(), [](size_t i, _VectorArgs *__va) { return CreateStructuredPattern(*__va->__fbb, __va->__o->patterns[i].get(), __va->__rehasher); }, &_va ) : 0;
+ return libtextclassifier2::CreateStructuredRegexModel(
+ _fbb,
+ _patterns);
+}
+
+namespace StructuredRegexModel_ {
+
+inline StructuredPatternT *StructuredPattern::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new StructuredPatternT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void StructuredPattern::UnPackTo(StructuredPatternT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = pattern(); if (_e) _o->pattern = _e->str(); };
+ { auto _e = node_names(); if (_e) { _o->node_names.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->node_names[_i] = _e->Get(_i)->str(); } } };
+}
+
+inline flatbuffers::Offset<StructuredPattern> StructuredPattern::Pack(flatbuffers::FlatBufferBuilder &_fbb, const StructuredPatternT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateStructuredPattern(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<StructuredPattern> CreateStructuredPattern(flatbuffers::FlatBufferBuilder &_fbb, const StructuredPatternT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const StructuredPatternT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _pattern = _o->pattern.empty() ? 0 : _fbb.CreateString(_o->pattern);
+ auto _node_names = _o->node_names.size() ? _fbb.CreateVectorOfStrings(_o->node_names) : 0;
+ return libtextclassifier2::StructuredRegexModel_::CreateStructuredPattern(
+ _fbb,
+ _pattern,
+ _node_names);
+}
+
+} // namespace StructuredRegexModel_
+
+inline ModelT *Model::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new ModelT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void Model::UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = language(); if (_e) _o->language = _e->str(); };
+ { auto _e = version(); _o->version = _e; };
+ { auto _e = selection_feature_options(); if (_e) _o->selection_feature_options = std::unique_ptr<FeatureProcessorOptionsT>(_e->UnPack(_resolver)); };
+ { auto _e = classification_feature_options(); if (_e) _o->classification_feature_options = std::unique_ptr<FeatureProcessorOptionsT>(_e->UnPack(_resolver)); };
+ { auto _e = selection_model(); if (_e) { _o->selection_model.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->selection_model[_i] = _e->Get(_i); } } };
+ { auto _e = classification_model(); if (_e) { _o->classification_model.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->classification_model[_i] = _e->Get(_i); } } };
+ { auto _e = embedding_model(); if (_e) { _o->embedding_model.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->embedding_model[_i] = _e->Get(_i); } } };
+ { auto _e = regex_options(); if (_e) _o->regex_options = std::unique_ptr<RegexModelOptionsT>(_e->UnPack(_resolver)); };
+ { auto _e = selection_options(); if (_e) _o->selection_options = std::unique_ptr<SelectionModelOptionsT>(_e->UnPack(_resolver)); };
+ { auto _e = classification_options(); if (_e) _o->classification_options = std::unique_ptr<ClassificationModelOptionsT>(_e->UnPack(_resolver)); };
+ { auto _e = regex_model(); if (_e) _o->regex_model = std::unique_ptr<StructuredRegexModelT>(_e->UnPack(_resolver)); };
+}
+
+inline flatbuffers::Offset<Model> Model::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateModel(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<Model> CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ModelT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _language = _o->language.empty() ? 0 : _fbb.CreateString(_o->language);
+ auto _version = _o->version;
+ auto _selection_feature_options = _o->selection_feature_options ? CreateFeatureProcessorOptions(_fbb, _o->selection_feature_options.get(), _rehasher) : 0;
+ auto _classification_feature_options = _o->classification_feature_options ? CreateFeatureProcessorOptions(_fbb, _o->classification_feature_options.get(), _rehasher) : 0;
+ auto _selection_model = _o->selection_model.size() ? _fbb.CreateVector(_o->selection_model) : 0;
+ auto _classification_model = _o->classification_model.size() ? _fbb.CreateVector(_o->classification_model) : 0;
+ auto _embedding_model = _o->embedding_model.size() ? _fbb.CreateVector(_o->embedding_model) : 0;
+ auto _regex_options = _o->regex_options ? CreateRegexModelOptions(_fbb, _o->regex_options.get(), _rehasher) : 0;
+ auto _selection_options = _o->selection_options ? CreateSelectionModelOptions(_fbb, _o->selection_options.get(), _rehasher) : 0;
+ auto _classification_options = _o->classification_options ? CreateClassificationModelOptions(_fbb, _o->classification_options.get(), _rehasher) : 0;
+ auto _regex_model = _o->regex_model ? CreateStructuredRegexModel(_fbb, _o->regex_model.get(), _rehasher) : 0;
+ return libtextclassifier2::CreateModel(
+ _fbb,
+ _language,
+ _version,
+ _selection_feature_options,
+ _classification_feature_options,
+ _selection_model,
+ _classification_model,
+ _embedding_model,
+ _regex_options,
+ _selection_options,
+ _classification_options,
+ _regex_model);
+}
+
+inline TokenizationCodepointRangeT *TokenizationCodepointRange::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new TokenizationCodepointRangeT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void TokenizationCodepointRange::UnPackTo(TokenizationCodepointRangeT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = start(); _o->start = _e; };
+ { auto _e = end(); _o->end = _e; };
+ { auto _e = role(); _o->role = _e; };
+ { auto _e = script_id(); _o->script_id = _e; };
+}
+
+inline flatbuffers::Offset<TokenizationCodepointRange> TokenizationCodepointRange::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TokenizationCodepointRangeT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateTokenizationCodepointRange(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<TokenizationCodepointRange> CreateTokenizationCodepointRange(flatbuffers::FlatBufferBuilder &_fbb, const TokenizationCodepointRangeT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TokenizationCodepointRangeT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _start = _o->start;
+ auto _end = _o->end;
+ auto _role = _o->role;
+ auto _script_id = _o->script_id;
+ return libtextclassifier2::CreateTokenizationCodepointRange(
+ _fbb,
+ _start,
+ _end,
+ _role,
+ _script_id);
+}
+
+inline FeatureProcessorOptionsT *FeatureProcessorOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new FeatureProcessorOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void FeatureProcessorOptions::UnPackTo(FeatureProcessorOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = num_buckets(); _o->num_buckets = _e; };
+ { auto _e = embedding_size(); _o->embedding_size = _e; };
+ { auto _e = context_size(); _o->context_size = _e; };
+ { auto _e = max_selection_span(); _o->max_selection_span = _e; };
+ { auto _e = chargram_orders(); if (_e) { _o->chargram_orders.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->chargram_orders[_i] = _e->Get(_i); } } };
+ { auto _e = max_word_length(); _o->max_word_length = _e; };
+ { auto _e = unicode_aware_features(); _o->unicode_aware_features = _e; };
+ { auto _e = extract_case_feature(); _o->extract_case_feature = _e; };
+ { auto _e = extract_selection_mask_feature(); _o->extract_selection_mask_feature = _e; };
+ { auto _e = regexp_feature(); if (_e) { _o->regexp_feature.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->regexp_feature[_i] = _e->Get(_i)->str(); } } };
+ { auto _e = remap_digits(); _o->remap_digits = _e; };
+ { auto _e = lowercase_tokens(); _o->lowercase_tokens = _e; };
+ { auto _e = selection_reduced_output_space(); _o->selection_reduced_output_space = _e; };
+ { auto _e = collections(); if (_e) { _o->collections.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->collections[_i] = _e->Get(_i)->str(); } } };
+ { auto _e = default_collection(); _o->default_collection = _e; };
+ { auto _e = only_use_line_with_click(); _o->only_use_line_with_click = _e; };
+ { auto _e = split_tokens_on_selection_boundaries(); _o->split_tokens_on_selection_boundaries = _e; };
+ { auto _e = tokenization_codepoint_config(); if (_e) { _o->tokenization_codepoint_config.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->tokenization_codepoint_config[_i] = std::unique_ptr<TokenizationCodepointRangeT>(_e->Get(_i)->UnPack(_resolver)); } } };
+ { auto _e = center_token_selection_method(); _o->center_token_selection_method = _e; };
+ { auto _e = snap_label_span_boundaries_to_containing_tokens(); _o->snap_label_span_boundaries_to_containing_tokens = _e; };
+ { auto _e = supported_codepoint_ranges(); if (_e) { _o->supported_codepoint_ranges.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->supported_codepoint_ranges[_i] = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>(_e->Get(_i)->UnPack(_resolver)); } } };
+ { auto _e = internal_tokenizer_codepoint_ranges(); if (_e) { _o->internal_tokenizer_codepoint_ranges.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->internal_tokenizer_codepoint_ranges[_i] = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CodepointRangeT>(_e->Get(_i)->UnPack(_resolver)); } } };
+ { auto _e = min_supported_codepoint_ratio(); _o->min_supported_codepoint_ratio = _e; };
+ { auto _e = feature_version(); _o->feature_version = _e; };
+ { auto _e = tokenization_type(); _o->tokenization_type = _e; };
+ { auto _e = icu_preserve_whitespace_tokens(); _o->icu_preserve_whitespace_tokens = _e; };
+ { auto _e = ignored_span_boundary_codepoints(); if (_e) { _o->ignored_span_boundary_codepoints.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->ignored_span_boundary_codepoints[_i] = _e->Get(_i); } } };
+ { auto _e = click_random_token_in_selection(); _o->click_random_token_in_selection = _e; };
+ { auto _e = alternative_collection_map(); if (_e) { _o->alternative_collection_map.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->alternative_collection_map[_i] = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntryT>(_e->Get(_i)->UnPack(_resolver)); } } };
+ { auto _e = bounds_sensitive_features(); if (_e) _o->bounds_sensitive_features = std::unique_ptr<libtextclassifier2::FeatureProcessorOptions_::BoundsSensitiveFeaturesT>(_e->UnPack(_resolver)); };
+ { auto _e = split_selection_candidates(); _o->split_selection_candidates = _e; };
+ { auto _e = allowed_chargrams(); if (_e) { _o->allowed_chargrams.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->allowed_chargrams[_i] = _e->Get(_i)->str(); } } };
+ { auto _e = tokenize_on_script_change(); _o->tokenize_on_script_change = _e; };
+}
+
+inline flatbuffers::Offset<FeatureProcessorOptions> FeatureProcessorOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateFeatureProcessorOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FeatureProcessorOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _num_buckets = _o->num_buckets;
+ auto _embedding_size = _o->embedding_size;
+ auto _context_size = _o->context_size;
+ auto _max_selection_span = _o->max_selection_span;
+ auto _chargram_orders = _o->chargram_orders.size() ? _fbb.CreateVector(_o->chargram_orders) : 0;
+ auto _max_word_length = _o->max_word_length;
+ auto _unicode_aware_features = _o->unicode_aware_features;
+ auto _extract_case_feature = _o->extract_case_feature;
+ auto _extract_selection_mask_feature = _o->extract_selection_mask_feature;
+ auto _regexp_feature = _o->regexp_feature.size() ? _fbb.CreateVectorOfStrings(_o->regexp_feature) : 0;
+ auto _remap_digits = _o->remap_digits;
+ auto _lowercase_tokens = _o->lowercase_tokens;
+ auto _selection_reduced_output_space = _o->selection_reduced_output_space;
+ auto _collections = _o->collections.size() ? _fbb.CreateVectorOfStrings(_o->collections) : 0;
+ auto _default_collection = _o->default_collection;
+ auto _only_use_line_with_click = _o->only_use_line_with_click;
+ auto _split_tokens_on_selection_boundaries = _o->split_tokens_on_selection_boundaries;
+ auto _tokenization_codepoint_config = _o->tokenization_codepoint_config.size() ? _fbb.CreateVector<flatbuffers::Offset<TokenizationCodepointRange>> (_o->tokenization_codepoint_config.size(), [](size_t i, _VectorArgs *__va) { return CreateTokenizationCodepointRange(*__va->__fbb, __va->__o->tokenization_codepoint_config[i].get(), __va->__rehasher); }, &_va ) : 0;
+ auto _center_token_selection_method = _o->center_token_selection_method;
+ auto _snap_label_span_boundaries_to_containing_tokens = _o->snap_label_span_boundaries_to_containing_tokens;
+ auto _supported_codepoint_ranges = _o->supported_codepoint_ranges.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> (_o->supported_codepoint_ranges.size(), [](size_t i, _VectorArgs *__va) { return CreateCodepointRange(*__va->__fbb, __va->__o->supported_codepoint_ranges[i].get(), __va->__rehasher); }, &_va ) : 0;
+ auto _internal_tokenizer_codepoint_ranges = _o->internal_tokenizer_codepoint_ranges.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CodepointRange>> (_o->internal_tokenizer_codepoint_ranges.size(), [](size_t i, _VectorArgs *__va) { return CreateCodepointRange(*__va->__fbb, __va->__o->internal_tokenizer_codepoint_ranges[i].get(), __va->__rehasher); }, &_va ) : 0;
+ auto _min_supported_codepoint_ratio = _o->min_supported_codepoint_ratio;
+ auto _feature_version = _o->feature_version;
+ auto _tokenization_type = _o->tokenization_type;
+ auto _icu_preserve_whitespace_tokens = _o->icu_preserve_whitespace_tokens;
+ auto _ignored_span_boundary_codepoints = _o->ignored_span_boundary_codepoints.size() ? _fbb.CreateVector(_o->ignored_span_boundary_codepoints) : 0;
+ auto _click_random_token_in_selection = _o->click_random_token_in_selection;
+ auto _alternative_collection_map = _o->alternative_collection_map.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::FeatureProcessorOptions_::CollectionMapEntry>> (_o->alternative_collection_map.size(), [](size_t i, _VectorArgs *__va) { return CreateCollectionMapEntry(*__va->__fbb, __va->__o->alternative_collection_map[i].get(), __va->__rehasher); }, &_va ) : 0;
+ auto _bounds_sensitive_features = _o->bounds_sensitive_features ? CreateBoundsSensitiveFeatures(_fbb, _o->bounds_sensitive_features.get(), _rehasher) : 0;
+ auto _split_selection_candidates = _o->split_selection_candidates;
+ auto _allowed_chargrams = _o->allowed_chargrams.size() ? _fbb.CreateVectorOfStrings(_o->allowed_chargrams) : 0;
+ auto _tokenize_on_script_change = _o->tokenize_on_script_change;
+ return libtextclassifier2::CreateFeatureProcessorOptions(
+ _fbb,
+ _num_buckets,
+ _embedding_size,
+ _context_size,
+ _max_selection_span,
+ _chargram_orders,
+ _max_word_length,
+ _unicode_aware_features,
+ _extract_case_feature,
+ _extract_selection_mask_feature,
+ _regexp_feature,
+ _remap_digits,
+ _lowercase_tokens,
+ _selection_reduced_output_space,
+ _collections,
+ _default_collection,
+ _only_use_line_with_click,
+ _split_tokens_on_selection_boundaries,
+ _tokenization_codepoint_config,
+ _center_token_selection_method,
+ _snap_label_span_boundaries_to_containing_tokens,
+ _supported_codepoint_ranges,
+ _internal_tokenizer_codepoint_ranges,
+ _min_supported_codepoint_ratio,
+ _feature_version,
+ _tokenization_type,
+ _icu_preserve_whitespace_tokens,
+ _ignored_span_boundary_codepoints,
+ _click_random_token_in_selection,
+ _alternative_collection_map,
+ _bounds_sensitive_features,
+ _split_selection_candidates,
+ _allowed_chargrams,
+ _tokenize_on_script_change);
+}
+
+namespace FeatureProcessorOptions_ {
+
+inline CodepointRangeT *CodepointRange::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new CodepointRangeT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void CodepointRange::UnPackTo(CodepointRangeT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = start(); _o->start = _e; };
+ { auto _e = end(); _o->end = _e; };
+}
+
+inline flatbuffers::Offset<CodepointRange> CodepointRange::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CodepointRangeT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateCodepointRange(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<CodepointRange> CreateCodepointRange(flatbuffers::FlatBufferBuilder &_fbb, const CodepointRangeT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CodepointRangeT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _start = _o->start;
+ auto _end = _o->end;
+ return libtextclassifier2::FeatureProcessorOptions_::CreateCodepointRange(
+ _fbb,
+ _start,
+ _end);
+}
+
+inline CollectionMapEntryT *CollectionMapEntry::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new CollectionMapEntryT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void CollectionMapEntry::UnPackTo(CollectionMapEntryT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = key(); if (_e) _o->key = _e->str(); };
+ { auto _e = value(); if (_e) _o->value = _e->str(); };
+}
+
+inline flatbuffers::Offset<CollectionMapEntry> CollectionMapEntry::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CollectionMapEntryT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateCollectionMapEntry(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<CollectionMapEntry> CreateCollectionMapEntry(flatbuffers::FlatBufferBuilder &_fbb, const CollectionMapEntryT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CollectionMapEntryT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _key = _o->key.empty() ? 0 : _fbb.CreateString(_o->key);
+ auto _value = _o->value.empty() ? 0 : _fbb.CreateString(_o->value);
+ return libtextclassifier2::FeatureProcessorOptions_::CreateCollectionMapEntry(
+ _fbb,
+ _key,
+ _value);
+}
+
+inline BoundsSensitiveFeaturesT *BoundsSensitiveFeatures::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new BoundsSensitiveFeaturesT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void BoundsSensitiveFeatures::UnPackTo(BoundsSensitiveFeaturesT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = enabled(); _o->enabled = _e; };
+ { auto _e = num_tokens_before(); _o->num_tokens_before = _e; };
+ { auto _e = num_tokens_inside_left(); _o->num_tokens_inside_left = _e; };
+ { auto _e = num_tokens_inside_right(); _o->num_tokens_inside_right = _e; };
+ { auto _e = num_tokens_after(); _o->num_tokens_after = _e; };
+ { auto _e = include_inside_bag(); _o->include_inside_bag = _e; };
+ { auto _e = include_inside_length(); _o->include_inside_length = _e; };
+}
+
+inline flatbuffers::Offset<BoundsSensitiveFeatures> BoundsSensitiveFeatures::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BoundsSensitiveFeaturesT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateBoundsSensitiveFeatures(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<BoundsSensitiveFeatures> CreateBoundsSensitiveFeatures(flatbuffers::FlatBufferBuilder &_fbb, const BoundsSensitiveFeaturesT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BoundsSensitiveFeaturesT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _enabled = _o->enabled;
+ auto _num_tokens_before = _o->num_tokens_before;
+ auto _num_tokens_inside_left = _o->num_tokens_inside_left;
+ auto _num_tokens_inside_right = _o->num_tokens_inside_right;
+ auto _num_tokens_after = _o->num_tokens_after;
+ auto _include_inside_bag = _o->include_inside_bag;
+ auto _include_inside_length = _o->include_inside_length;
+ return libtextclassifier2::FeatureProcessorOptions_::CreateBoundsSensitiveFeatures(
+ _fbb,
+ _enabled,
+ _num_tokens_before,
+ _num_tokens_inside_left,
+ _num_tokens_inside_right,
+ _num_tokens_after,
+ _include_inside_bag,
+ _include_inside_length);
+}
+
+} // namespace FeatureProcessorOptions_
+} // namespace libtextclassifier2
+
+#endif // FLATBUFFERS_GENERATED_MODEL_LIBTEXTCLASSIFIER2_FEATUREPROCESSOROPTIONS__H_
diff --git a/models/textclassifier.en.model b/models/textclassifier.en.model
new file mode 100644
index 0000000..9814c93
--- /dev/null
+++ b/models/textclassifier.en.model
Binary files differ
diff --git a/models/textclassifier.langid.model b/models/textclassifier.langid.model
deleted file mode 100644
index 6b68223..0000000
--- a/models/textclassifier.langid.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.smartselection.ar.model b/models/textclassifier.smartselection.ar.model
deleted file mode 100644
index f22fe0f..0000000
--- a/models/textclassifier.smartselection.ar.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.smartselection.de.model b/models/textclassifier.smartselection.de.model
deleted file mode 100644
index 5eb3181..0000000
--- a/models/textclassifier.smartselection.de.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.smartselection.en.model b/models/textclassifier.smartselection.en.model
deleted file mode 100644
index 7af0897..0000000
--- a/models/textclassifier.smartselection.en.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.smartselection.es.model b/models/textclassifier.smartselection.es.model
deleted file mode 100644
index 9ea6af9..0000000
--- a/models/textclassifier.smartselection.es.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.smartselection.fr.model b/models/textclassifier.smartselection.fr.model
deleted file mode 100644
index 3ff5416..0000000
--- a/models/textclassifier.smartselection.fr.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.smartselection.it.model b/models/textclassifier.smartselection.it.model
deleted file mode 100644
index 377fff5..0000000
--- a/models/textclassifier.smartselection.it.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.smartselection.ja.model b/models/textclassifier.smartselection.ja.model
deleted file mode 100644
index 53fce93..0000000
--- a/models/textclassifier.smartselection.ja.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.smartselection.ko.model b/models/textclassifier.smartselection.ko.model
deleted file mode 100644
index 6bcac15..0000000
--- a/models/textclassifier.smartselection.ko.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.smartselection.nl.model b/models/textclassifier.smartselection.nl.model
deleted file mode 100644
index c80dff6..0000000
--- a/models/textclassifier.smartselection.nl.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.smartselection.pl.model b/models/textclassifier.smartselection.pl.model
deleted file mode 100644
index 3379c63..0000000
--- a/models/textclassifier.smartselection.pl.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.smartselection.pt.model b/models/textclassifier.smartselection.pt.model
deleted file mode 100644
index 4378c8f..0000000
--- a/models/textclassifier.smartselection.pt.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.smartselection.ru.model b/models/textclassifier.smartselection.ru.model
deleted file mode 100644
index 0763b33..0000000
--- a/models/textclassifier.smartselection.ru.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.smartselection.th.model b/models/textclassifier.smartselection.th.model
deleted file mode 100644
index 521fea0..0000000
--- a/models/textclassifier.smartselection.th.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.smartselection.tr.model b/models/textclassifier.smartselection.tr.model
deleted file mode 100644
index 0177175..0000000
--- a/models/textclassifier.smartselection.tr.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.smartselection.zh-Hant.model b/models/textclassifier.smartselection.zh-Hant.model
deleted file mode 100644
index ec03c26..0000000
--- a/models/textclassifier.smartselection.zh-Hant.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.smartselection.zh.model b/models/textclassifier.smartselection.zh.model
deleted file mode 100644
index acc6142..0000000
--- a/models/textclassifier.smartselection.zh.model
+++ /dev/null
Binary files differ
diff --git a/regex-base.cc b/regex-base.cc
new file mode 100644
index 0000000..790e453
--- /dev/null
+++ b/regex-base.cc
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "regex-base.h"
+
+namespace libtextclassifier2 {
+
+Rules::Rules(const std::string &locale) : locale_(locale) {}
+
+FlatBufferRules::FlatBufferRules(const std::string &locale, const Model *model)
+ : Rules(locale), model_(model), rule_cache_() {
+ for (int i = 0; i < model_->regex_model()->patterns()->Length(); i++) {
+ auto regex_model = model_->regex_model()->patterns()->Get(i);
+ for (int j = 0; j < regex_model->node_names()->Length(); j++) {
+ std::string name = regex_model->node_names()->Get(j)->str();
+ rule_cache_[name].push_back(regex_model->pattern());
+ }
+ }
+}
+
+bool FlatBufferRules::RuleForName(const std::string &name,
+ std::string *out) const {
+ const auto match = rule_cache_.find(name);
+ if (match != rule_cache_.end()) {
+ if (match->second.size() != 1) {
+ TC_LOG(ERROR) << "Found " << match->second.size()
+ << " rule where only 1 was expected.";
+ return false;
+ }
+ *out = match->second[0]->str();
+ return true;
+ }
+ return false;
+}
+
+const std::vector<std::string> FlatBufferRules::RulesForName(
+ const std::string &name) const {
+ std::vector<std::string> results;
+ const auto match = rule_cache_.find(name);
+ if (match != rule_cache_.end()) {
+ for (auto &s : match->second) {
+ results.push_back(s->str());
+ }
+ }
+ return results;
+}
+
+} // namespace libtextclassifier2
diff --git a/regex-base.h b/regex-base.h
new file mode 100644
index 0000000..5856eba
--- /dev/null
+++ b/regex-base.h
@@ -0,0 +1,472 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_REGEX_BASE_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_REGEX_BASE_H_
+
+#include <iostream>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "model_generated.h"
+#include "util/base/logging.h"
+#include "util/memory/mmap.h"
+#include "unicode/regex.h"
+#include "unicode/uchar.h"
+
+namespace libtextclassifier2 {
+
+// Encapsulates the start and end of a region of a string of an entity that
+// has been mapped to an element of type T
+template <class T>
+class SpanResult {
+ private:
+ const T data_;
+ const int start_;
+ const int end_;
+
+ public:
+ SpanResult(int start, int end, T data)
+ : data_(data), start_(start), end_(end) {}
+
+ const T &Data() const { return data_; }
+
+ int Start() const { return start_; }
+
+ int End() const { return end_; }
+};
+
+// Interface supplying class that provides a protocol for encapsulating a unit
+// of text processing that can match against strings and also extract values of
+// SpanResults of type T. Implementations are expected to be thread-safe.
+template <class T>
+class Node {
+ public:
+ typedef SpanResult<T> Result;
+ typedef std::vector<Result> Results;
+
+ // Returns true if the Node can find a region of the string which
+ // matches its logic.
+ virtual bool Matches(const std::string &input) const = 0;
+
+ // Populates the supplied Results vector with the values obtained by
+ // extracting from any matching elements of the string.
+ // Returns true if the processing yielded no error.
+ // Returns false if there was an error processing the string.
+ virtual bool Extract(const std::string &input, Results *result) const = 0;
+
+ virtual ~Node() = default;
+};
+
+class Rules {
+ public:
+ explicit Rules(const std::string &locale);
+ virtual ~Rules() = default;
+ virtual const std::vector<std::string> RulesForName(
+ const std::string &name) const = 0;
+ virtual bool RuleForName(const std::string &name, std::string *out) const = 0;
+
+ protected:
+ const std::string locale_;
+};
+
+class FlatBufferRules : public Rules {
+ public:
+ FlatBufferRules(const std::string &locale, const Model *model);
+ const std::vector<std::string> RulesForName(
+ const std::string &name) const override;
+ bool RuleForName(const std::string &name, std::string *out) const override;
+
+ protected:
+ const Model *model_;
+ std::unordered_map<std::string, std::vector<const flatbuffers::String *>>
+ rule_cache_;
+};
+
+// Abstract supplying class that provides a protocol for encapsulating a unit
+// of regular expression processing that can match against strings and also
+// extract values of SpanResults of type T. Implementations are expected to be
+// thread-safe.
+// Implementors of this class are expected to call Init() prior to calling
+// Extract() or Match().
+template <class T>
+class RegexNode : public Node<T> {
+ public:
+ typedef typename Node<T>::Result Result;
+ typedef typename Node<T>::Results Results;
+
+ RegexNode() : pattern_() {}
+
+ // A string representation of the regular expression used in this processing.
+ const std::string pattern() const {
+ std::string regex;
+ pattern_->pattern().toUTF8String(regex);
+ return regex;
+ }
+
+ bool Matches(const std::string &input) const override {
+ UErrorCode status = U_ZERO_ERROR;
+ const icu::UnicodeString unicode_context(input.c_str(), input.size(),
+ "utf-8");
+ std::unique_ptr<icu::RegexMatcher> matcher(
+ pattern_->matcher(unicode_context, status));
+
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to load regex '" << pattern()
+ << "': " << u_errorName(status);
+ return false;
+ }
+ const bool res = matcher->find(status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to find with regex '" << pattern()
+ << "': " << u_errorName(status);
+ return false;
+ }
+ return res;
+ }
+
+ bool Extract(const std::string &input, Results *result) const override = 0;
+ ~RegexNode() override {}
+
+ protected:
+ bool Init(const std::string ®ex) {
+ UErrorCode status = U_ZERO_ERROR;
+ pattern_ = std::unique_ptr<icu::RegexPattern>(
+ icu::RegexPattern::compile(regex.c_str(), UREGEX_MULTILINE, status));
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to compile regex '" << pattern()
+ << "': " << u_errorName(status);
+ return false;
+ }
+ return true;
+ }
+
+ std::unique_ptr<icu::RegexPattern> pattern_;
+};
+
+// Class that encapsulates complex object matching and extracts values
+// from multiple sub-patterns, each of which is mapped to a string describing
+// the value.
+// Thread-safe.
+template <class T>
+class CompoundNode : public RegexNode<std::unordered_map<std::string, T>> {
+ public:
+ typedef std::unordered_map<std::string, std::unique_ptr<const Node<T>>>
+ Extractors;
+ typedef typename RegexNode<std::unordered_map<std::string, T>>::Result Result;
+ typedef
+ typename RegexNode<std::unordered_map<std::string, T>>::Results Results;
+
+ static std::unique_ptr<CompoundNode<T>> Instance(const std::string &rule,
+ const Extractors &extractors,
+ const Rules &rules) {
+ std::unique_ptr<CompoundNode<T>> node(new CompoundNode(extractors));
+ if (!node->Init(rule, rules)) {
+ return nullptr;
+ }
+ return node;
+ }
+
+ bool Extract(const std::string &input, Results *result) const override {
+ UErrorCode status = U_ZERO_ERROR;
+ const icu::UnicodeString unicode_context(input.c_str(), input.size(),
+ "utf-8");
+ std::unique_ptr<icu::RegexMatcher> matcher(
+ RegexNode<std::unordered_map<std::string, T>>::pattern_->matcher(
+ unicode_context, status));
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error loading regex '"
+ << RegexNode<std::unordered_map<std::string, T>>::pattern()
+ << "'" << u_errorName(status);
+ return false;
+ }
+ while (matcher->find() && U_SUCCESS(status)) {
+ const int start = matcher->start(status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR)
+ << "failed to demarshall start '"
+ << RegexNode<std::unordered_map<std::string, T>>::pattern() << "'"
+ << u_errorName(status);
+ return false;
+ }
+
+ const int end = matcher->end(status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR)
+ << "failed to demarshall end '"
+ << RegexNode<std::unordered_map<std::string, T>>::pattern() << "'"
+ << u_errorName(status);
+ return false;
+ }
+
+ std::unordered_map<std::string, T> extraction;
+ for (auto name : groupnames_) {
+ const int group_number = matcher->pattern().groupNumberFromName(
+ icu::UnicodeString(name.c_str(), name.size(), "utf-8"), status);
+ if (U_FAILURE(status)) {
+ // We expect this to happen for optional named groups.
+ continue;
+ }
+
+ std::string capture;
+ matcher->group(group_number, status).toUTF8String(capture);
+ std::vector<SpanResult<T>> sub_result;
+
+ if (!extractors_->find(name)->second->Extract(capture, &sub_result)) {
+ return false;
+ }
+ if (!sub_result.empty()) {
+ extraction[name] = sub_result[0].Data();
+ }
+ }
+ result->push_back(Result(start, end, extraction));
+ }
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to extract '"
+ << RegexNode<std::unordered_map<std::string, T>>::pattern()
+ << "':" << u_errorName(status);
+ return false;
+ }
+ return true;
+ }
+
+ private:
+ const Extractors *extractors_;
+ std::vector<std::string> groupnames_;
+
+ explicit CompoundNode(const Extractors &extractors)
+ : extractors_(&extractors), groupnames_() {}
+
+ bool Init(const std::string &rule, const Rules &rules) {
+ static const icu::RegexPattern *pattern = []() {
+ UErrorCode status = U_ZERO_ERROR;
+ return icu::RegexPattern::compile("[?]<([A-Z_]+)>", UREGEX_MULTILINE,
+ status);
+ }();
+ if (!pattern) {
+ return false;
+ }
+ std::string source = rule;
+ UErrorCode status = U_ZERO_ERROR;
+ const icu::UnicodeString unicode_context(source.c_str(), source.size(),
+ "utf-8");
+ std::unique_ptr<icu::RegexMatcher> matcher(
+ pattern->matcher(unicode_context, status));
+
+ std::unordered_map<std::string, std::string> swaps;
+ while (matcher->find(status)) {
+ std::string name;
+ matcher->group(1, status).toUTF8String(name);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to demarshall name " << u_errorName(status);
+ return false;
+ }
+ groupnames_.push_back(name);
+ }
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to execute the regex properly"
+ << u_errorName(status);
+ return false;
+ }
+ for (auto swap : swaps) {
+ std::string::size_type n = 0;
+ while ((n = source.find(swap.first, n)) != std::string::npos) {
+ source.replace(n, swap.first.size(), swap.second);
+ n += swap.second.size();
+ }
+ }
+ return RegexNode<std::unordered_map<std::string, T>>::Init(source);
+ }
+};
+
+// Class that manages multiple alternate Nodes that all yield the same result
+// and can be used interchangeably. It returns results only from the first
+// successfully matching child node - thus Nodes can be given to it in
+// order of precedence.
+// Thread-safe.
+template <class T>
+class OrNode : public Node<T> {
+ public:
+ typedef typename Node<T>::Result Result;
+ typedef typename Node<T>::Results Results;
+
+ explicit OrNode(std::vector<std::unique_ptr<const Node<T>>> alternatives)
+ : alternatives_(std::move(alternatives)) {}
+
+ bool Extract(const std::string &input, Results *result) const override {
+ for (auto &alternative : alternatives_) {
+ typename Node<T>::Results alternative_result;
+ // NOTE: We are explicitly choosing to fall through errors on these
+ // alternatives to try a lesser match instead of bailing on the user.
+ if (alternative->Extract(input, &alternative_result)) {
+ if (!alternative_result.empty()) {
+ for (auto &s : alternative_result) {
+ result->push_back(s);
+ }
+ return true;
+ }
+ }
+ }
+ return true;
+ }
+
+ bool Matches(const std::string &input) const override {
+ for (auto &alternative : alternatives_) {
+ if (alternative->Matches(input)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ private:
+ std::vector<std::unique_ptr<const Node<T>>> alternatives_;
+};
+
+// Class that manages multiple alternate RegexNodes that all yield the same
+// result and can be used interchangeably. It returns results only from the
+// first successfully matching child node - thus Nodes can be given to it in
+// order of precedence.
+// Thread-safe.
+template <class T>
+class OrRegexNode : public RegexNode<T> {
+ public:
+ typedef typename RegexNode<T>::Result Result;
+ typedef typename RegexNode<T>::Results Results;
+
+ bool Extract(const std::string &input, Results *result) const override {
+ for (auto &alternative : alternatives_) {
+ typename RegexNode<T>::Results alternative_result;
+ // NOTE: we are explicitly choosing to fall through errors on these
+ // alternatives to try a lesser match instead of bailing on the user
+ if (alternative->Extract(input, &alternative_result)) {
+ if (!alternative_result.empty()) {
+ for (typename RegexNode<T>::Result &s : alternative_result) {
+ result->push_back(s);
+ }
+ return true;
+ }
+ }
+ }
+ return true;
+ }
+
+ protected:
+ std::vector<std::unique_ptr<const RegexNode<T>>> alternatives_;
+
+ explicit OrRegexNode(
+ std::vector<std::unique_ptr<const RegexNode<T>>> alternatives)
+ : alternatives_(std::move(alternatives)) {}
+
+ bool Init() {
+ std::string pattern;
+ for (int i = 0; i < alternatives_.size(); i++) {
+ if (i == 0) {
+ pattern = alternatives_[i]->pattern();
+ } else {
+ pattern += "|";
+ pattern += alternatives_[i]->pattern();
+ }
+ }
+ return RegexNode<T>::Init(pattern);
+ }
+};
+
+// Class that yields a constant value for any string that matches the input
+// Thread-safe.
+template <class T>
+class MappingNode : public RegexNode<T> {
+ public:
+ typedef RegexNode<T> Parent;
+ typedef typename Parent::Result Result;
+ typedef typename Parent::Results Results;
+
+ static std::unique_ptr<MappingNode<T>> Instance(const std::string &name,
+ const T &value,
+ const Rules &rules) {
+ std::unique_ptr<MappingNode<T>> node(new MappingNode<T>(value));
+ if (!node->Init(name, rules)) {
+ return nullptr;
+ }
+ return node;
+ }
+
+ bool Extract(const std::string &input, Results *result) const override {
+ UErrorCode status = U_ZERO_ERROR;
+
+ const icu::UnicodeString unicode_context(input.c_str(), input.size(),
+ "utf-8");
+ std::unique_ptr<icu::RegexMatcher> matcher(
+ Parent::pattern_->matcher(unicode_context, status));
+
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "error loading regex ";
+ return false;
+ }
+
+ while (matcher->find() && U_SUCCESS(status)) {
+ const int start = matcher->start(status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to demarshall start " << u_errorName(status);
+ return false;
+ }
+ const int end = matcher->end(status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to demarshall end " << u_errorName(status);
+ return false;
+ }
+ result->push_back(Result(start, end, value_));
+ }
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to demarshall end " << u_errorName(status);
+ return false;
+ }
+ return true;
+ }
+
+ private:
+ explicit MappingNode(const T value) : value_(value) {}
+
+ const T value_;
+
+ bool Init(const std::string &name, const Rules &rules) {
+ std::string pattern;
+ if (!rules.RuleForName(name, &pattern)) {
+ TC_LOG(ERROR) << "failed to load rule for name '" << name << "'";
+ return false;
+ }
+ return RegexNode<T>::Init(pattern);
+ }
+};
+
+template <class T>
+bool BuildMappings(const Rules &rules,
+ const std::vector<std::pair<std::string, T>> &pairs,
+ std::vector<std::unique_ptr<const RegexNode<T>>> *mappings) {
+ for (auto &p : pairs) {
+ if (std::unique_ptr<RegexNode<T>> node =
+ MappingNode<T>::Instance(p.first, p.second, rules)) {
+ mappings->emplace_back(std::move(node));
+ } else {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace libtextclassifier2
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_REGEX_BASE_H_
diff --git a/regex-number.cc b/regex-number.cc
new file mode 100644
index 0000000..ed5a899
--- /dev/null
+++ b/regex-number.cc
@@ -0,0 +1,254 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "regex-number.h"
+
+namespace libtextclassifier2 {
+
+std::unique_ptr<DigitNode> DigitNode::Instance(const Rules& rules) {
+ std::unique_ptr<DigitNode> node(new DigitNode());
+ if (!node->Init(rules)) {
+ return nullptr;
+ }
+ return node;
+}
+
+DigitNode::DigitNode() {}
+
+constexpr const char* kDigits = "DIGITS";
+
+bool DigitNode::Init(const Rules& rules) {
+ std::string pattern;
+ if (!rules.RuleForName(kDigits, &pattern)) {
+ TC_LOG(ERROR) << "failed to load pattern";
+ return false;
+ }
+ return RegexNode<int>::Init(pattern);
+}
+
+bool DigitNode::Extract(const std::string& input, Results* result) const {
+ UErrorCode status = U_ZERO_ERROR;
+
+ const icu::UnicodeString unicode_context(input.c_str(), input.size(),
+ "utf-8");
+ const std::unique_ptr<icu::RegexMatcher> matcher(
+ pattern_->matcher(unicode_context, status));
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to compile regex: " << u_errorName(status);
+ return false;
+ }
+
+ while (matcher->find() && U_SUCCESS(status)) {
+ const int start = matcher->start(status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to bind start: " << u_errorName(status);
+ return false;
+ }
+ const int end = matcher->end(status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to bind end: " << u_errorName(status);
+ return false;
+ }
+ std::string digit;
+ matcher->group(status).toUTF8String(digit);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to bind digit std::string: "
+ << u_errorName(status);
+ return false;
+ }
+ result->push_back(Result(start, end, stoi(digit)));
+ }
+ return true;
+}
+
+constexpr const char* kSignedDigits = "SIGNEDDIGITS";
+
+std::unique_ptr<SignedDigitNode> SignedDigitNode::Instance(const Rules& rules) {
+ std::unique_ptr<SignedDigitNode> node(new SignedDigitNode());
+ if (!node->Init(rules)) {
+ return nullptr;
+ }
+ return node;
+}
+
+SignedDigitNode::SignedDigitNode() {}
+
+bool SignedDigitNode::Init(const Rules& rules) {
+ std::string pattern;
+ if (!rules.RuleForName(kSignedDigits, &pattern)) {
+ TC_LOG(ERROR) << "failed to load pattern";
+ return false;
+ }
+ return RegexNode<int>::Init(pattern);
+}
+
+bool SignedDigitNode::Extract(const std::string& input, Results* result) const {
+ UErrorCode status = U_ZERO_ERROR;
+
+ const icu::UnicodeString unicode_context(input.c_str(), input.size(),
+ "utf-8");
+ const std::unique_ptr<icu::RegexMatcher> matcher(
+ pattern_->matcher(unicode_context, status));
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to compile regex: " << u_errorName(status);
+ return false;
+ }
+
+ while (matcher->find() && U_SUCCESS(status)) {
+ const int start = matcher->start(status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to bind start: " << u_errorName(status);
+ return false;
+ }
+ const int end = matcher->end(status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to bind end: " << u_errorName(status);
+ return false;
+ }
+ std::string digit;
+ matcher->group(status).toUTF8String(digit);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to bind digit std::string: "
+ << u_errorName(status);
+ return false;
+ }
+ result->push_back(Result(start, end, stoi(digit)));
+ }
+ return true;
+}
+
+constexpr const char* kZero = "ZERO";
+constexpr const char* kOne = "ONE";
+constexpr const char* kTwo = "TWO";
+constexpr const char* kThree = "THREE";
+constexpr const char* kFour = "FOUR";
+constexpr const char* kFive = "FIVE";
+constexpr const char* kSix = "SIX";
+constexpr const char* kSeven = "SEVEN";
+constexpr const char* kEight = "EIGHT";
+constexpr const char* kNine = "NINE";
+constexpr const char* kTen = "TEN";
+constexpr const char* kEleven = "ELEVEN";
+constexpr const char* kTwelve = "TWELVE";
+constexpr const char* kThirteen = "THIRTEEN";
+constexpr const char* kFourteen = "FOURTEEN";
+constexpr const char* kFifteen = "FIFTEEN";
+constexpr const char* kSixteen = "SIXTEEN";
+constexpr const char* kSeventeen = "SEVENTEEN";
+constexpr const char* kEighteen = "EIGHTEEN";
+constexpr const char* kNineteen = "NINETEEN";
+constexpr const char* kTwenty = "TWENTY";
+constexpr const char* kThirty = "THIRTY";
+constexpr const char* kForty = "FORTY";
+constexpr const char* kFifty = "FIFTY";
+constexpr const char* kSixty = "SIXTY";
+constexpr const char* kSeventy = "SEVENTY";
+constexpr const char* kEighty = "EIGHTY";
+constexpr const char* kNinety = "NINETY";
+constexpr const char* kHundred = "HUNDRED";
+constexpr const char* kThousand = "THOUSAND";
+
+std::unique_ptr<NumberNode> NumberNode::Instance(const Rules& rules) {
+ const std::vector<std::pair<std::string, int>> name_values = {
+ {kZero, 0}, {kOne, 1}, {kTwo, 2}, {kThree, 3},
+ {kFour, 4}, {kFive, 5}, {kSix, 6}, {kSeven, 7},
+ {kEight, 8}, {kNine, 9}, {kTen, 10}, {kEleven, 11},
+ {kTwelve, 12}, {kThirteen, 13}, {kFourteen, 14}, {kFifteen, 15},
+ {kSixteen, 16}, {kSeventeen, 17}, {kEighteen, 18}, {kNineteen, 19},
+ {kTwenty, 20}, {kThirty, 30}, {kForty, 40}, {kFifty, 50},
+ {kSixty, 60}, {kSeventy, 70}, {kEighty, 80}, {kNinety, 90},
+ {kHundred, 100}, {kThousand, 1000},
+ };
+ std::vector<std::unique_ptr<const RegexNode<int>>> alternatives;
+ if (!BuildMappings<int>(rules, name_values, &alternatives)) {
+ return nullptr;
+ }
+ std::unique_ptr<NumberNode> node(new NumberNode(std::move(alternatives)));
+ if (!node->Init()) {
+ return nullptr;
+ }
+ return node;
+} // namespace libtextclassifier2
+
+bool NumberNode::Init() { return OrRegexNode<int>::Init(); }
+
+NumberNode::NumberNode(
+ std::vector<std::unique_ptr<const RegexNode<int>>> alternatives)
+ : OrRegexNode<int>(std::move(alternatives)) {}
+
+bool NumberNode::Extract(const std::string& input, Results* result) const {
+ UErrorCode status = U_ZERO_ERROR;
+
+ const icu::UnicodeString unicode_context(input.c_str(), input.size(),
+ "utf-8");
+ const std::unique_ptr<icu::RegexMatcher> matcher(
+ RegexNode<int>::pattern_->matcher(unicode_context, status));
+
+ OrRegexNode<int>::Results parts;
+ int start = 0;
+ int end = 0;
+ while (matcher->find() && U_SUCCESS(status)) {
+ std::string group;
+ matcher->group(0, status).toUTF8String(group);
+ int span_start = matcher->start(status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to demarshall start " << u_errorName(status);
+ return false;
+ }
+ int span_end = matcher->end(status);
+ if (U_FAILURE(status)) {
+ TC_LOG(ERROR) << "failed to demarshall end " << u_errorName(status);
+ }
+ if (span_start < start) {
+ start = span_start;
+ }
+ if (span_end < end) {
+ end = span_end;
+ }
+
+ for (auto& child : alternatives_) {
+ if (child->Matches(group)) {
+ OrRegexNode<int>::Results group_results;
+ if (!child->Extract(group, &group_results)) {
+ return false;
+ }
+ for (OrRegexNode<int>::Result span : group_results) {
+ parts.push_back(span);
+ }
+ }
+ }
+ }
+ int sum = 0;
+ int running_value = -1;
+ // Simple math to make sure we handle written numerical modifiers correctly
+ // so that "fifty one thousand and one" maps to 51001 and not 50 1 1000 1.
+ for (OrRegexNode<int>::Result part : parts) {
+ if (running_value >= 0) {
+ if (running_value > part.Data()) {
+ sum += running_value;
+ running_value = part.Data();
+ } else {
+ running_value *= part.Data();
+ }
+ } else {
+ running_value = part.Data();
+ }
+ }
+ sum += running_value;
+ result->push_back(Result(start, end, sum));
+ return true;
+}
+} // namespace libtextclassifier2
diff --git a/regex-number.h b/regex-number.h
new file mode 100644
index 0000000..286e57d
--- /dev/null
+++ b/regex-number.h
@@ -0,0 +1,110 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_REGEX_NUMBER_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_REGEX_NUMBER_H_
+
+#include "regex-base.h"
+
+namespace libtextclassifier2 {
+
+extern const char* const kZero;
+extern const char* const kOne;
+extern const char* const kTwo;
+extern const char* const kThree;
+extern const char* const kFour;
+extern const char* const kFive;
+extern const char* const kSix;
+extern const char* const kSeven;
+extern const char* const kEight;
+extern const char* const kNine;
+extern const char* const kTen;
+extern const char* const kEleven;
+extern const char* const kTwelve;
+extern const char* const kThirteen;
+extern const char* const kFourteen;
+extern const char* const kFifteen;
+extern const char* const kSixteen;
+extern const char* const kSeventeen;
+extern const char* const kEighteen;
+extern const char* const kNineteen;
+extern const char* const kTwenty;
+extern const char* const kThirty;
+extern const char* const kForty;
+extern const char* const kFifty;
+extern const char* const kSixty;
+extern const char* const kSeventy;
+extern const char* const kEighty;
+extern const char* const kNinety;
+extern const char* const kHundred;
+extern const char* const kThousand;
+
+extern const char* const kDigits;
+extern const char* const kSignedDigits;
+
+// Class that encapsulates a unsigned integer matching and extract values of
+// SpanResults of type int.
+// Thread-safe.
+class DigitNode : public RegexNode<int> {
+ public:
+ typedef typename RegexNode<int>::Result Result;
+ typedef typename RegexNode<int>::Results Results;
+
+ // Factory method for yielding a pointer to a DigitNode implementation.
+ static std::unique_ptr<DigitNode> Instance(const Rules& rules);
+ bool Extract(const std::string& input, Results* result) const override;
+
+ protected:
+ DigitNode();
+ bool Init(const Rules& rules);
+};
+
+// Class that encapsulates a signed integer matching and extract values of
+// SpanResults of type int.
+// Thread-safe.
+class SignedDigitNode : public RegexNode<int> {
+ public:
+ typedef typename RegexNode<int>::Result Result;
+ typedef typename RegexNode<int>::Results Results;
+
+ // Factory method for yielding a pointer to a DigitNode implementation.
+ static std::unique_ptr<SignedDigitNode> Instance(const Rules& rules);
+ bool Extract(const std::string& input, Results* result) const override;
+
+ protected:
+ SignedDigitNode();
+ bool Init(const Rules& rules);
+};
+
+// Class that encapsulates a simple natural language integer matching and
+// extract values of SpanResults of type int.
+// Thread-safe.
+class NumberNode : public OrRegexNode<int> {
+ public:
+ typedef typename OrRegexNode<int>::Result Result;
+ typedef typename OrRegexNode<int>::Results Results;
+
+ static std::unique_ptr<NumberNode> Instance(const Rules& rules);
+ bool Extract(const std::string& input, Results* result) const override;
+
+ protected:
+ explicit NumberNode(
+ std::vector<std::unique_ptr<const RegexNode<int>>> alternatives);
+ bool Init();
+};
+
+} // namespace libtextclassifier2
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_REGEX_NUMBER_H_
diff --git a/smartselect/cached-features.cc b/smartselect/cached-features.cc
deleted file mode 100644
index c249db9..0000000
--- a/smartselect/cached-features.cc
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "smartselect/cached-features.h"
-#include "util/base/logging.h"
-
-namespace libtextclassifier {
-
-void CachedFeatures::Extract(
- const std::vector<std::vector<int>>& sparse_features,
- const std::vector<std::vector<float>>& dense_features,
- const std::function<bool(const std::vector<int>&, const std::vector<float>&,
- float*)>& feature_vector_fn) {
- features_.resize(feature_vector_size_ * tokens_.size());
- for (int i = 0; i < tokens_.size(); ++i) {
- feature_vector_fn(sparse_features[i], dense_features[i],
- features_.data() + i * feature_vector_size_);
- }
-}
-
-bool CachedFeatures::Get(int click_pos, VectorSpan<float>* features,
- VectorSpan<Token>* output_tokens) {
- const int token_start = click_pos - context_size_;
- const int token_end = click_pos + context_size_ + 1;
- if (token_start < 0 || token_end > tokens_.size()) {
- TC_LOG(ERROR) << "Tokens out of range: " << token_start << " " << token_end;
- return false;
- }
-
- *features =
- VectorSpan<float>(features_.begin() + token_start * feature_vector_size_,
- features_.begin() + token_end * feature_vector_size_);
- *output_tokens = VectorSpan<Token>(tokens_.begin() + token_start,
- tokens_.begin() + token_end);
- if (remap_v0_feature_vector_) {
- RemapV0FeatureVector(features);
- }
-
- return true;
-}
-
-void CachedFeatures::RemapV0FeatureVector(VectorSpan<float>* features) {
- if (!remap_v0_feature_vector_) {
- return;
- }
-
- auto it = features->begin();
- int num_suffix_features =
- feature_vector_size_ - remap_v0_chargram_embedding_size_;
- int num_tokens = context_size_ * 2 + 1;
- for (int t = 0; t < num_tokens; ++t) {
- for (int i = 0; i < remap_v0_chargram_embedding_size_; ++i) {
- v0_feature_storage_[t * remap_v0_chargram_embedding_size_ + i] = *it;
- ++it;
- }
- // Rest of the features are the dense features that come to the end.
- for (int i = 0; i < num_suffix_features; ++i) {
- // clang-format off
- v0_feature_storage_[num_tokens * remap_v0_chargram_embedding_size_
- + t * num_suffix_features
- + i] = *it;
- // clang-format on
- ++it;
- }
- }
- *features = VectorSpan<float>(v0_feature_storage_);
-}
-
-} // namespace libtextclassifier
diff --git a/smartselect/cached-features.h b/smartselect/cached-features.h
deleted file mode 100644
index 990233c..0000000
--- a/smartselect/cached-features.h
+++ /dev/null
@@ -1,105 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
-#define LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
-
-#include <memory>
-#include <vector>
-
-#include "common/vector-span.h"
-#include "smartselect/types.h"
-
-namespace libtextclassifier {
-
-// Holds state for extracting features across multiple calls and reusing them.
-// Assumes that features for each Token are independent.
-class CachedFeatures {
- public:
- // Extracts the features for the given sequence of tokens.
- // - context_size: Specifies how many tokens to the left, and how many
- // tokens to the right spans the context.
- // - sparse_features, dense_features: Extracted features for each token.
- // - feature_vector_fn: Writes features for given Token to the specified
- // storage.
- // NOTE: The function can assume that the underlying
- // storage is initialized to all zeros.
- // - feature_vector_size: Size of a feature vector for one Token.
- CachedFeatures(VectorSpan<Token> tokens, int context_size,
- const std::vector<std::vector<int>>& sparse_features,
- const std::vector<std::vector<float>>& dense_features,
- const std::function<bool(const std::vector<int>&,
- const std::vector<float>&, float*)>&
- feature_vector_fn,
- int feature_vector_size)
- : tokens_(tokens),
- context_size_(context_size),
- feature_vector_size_(feature_vector_size),
- remap_v0_feature_vector_(false),
- remap_v0_chargram_embedding_size_(-1) {
- Extract(sparse_features, dense_features, feature_vector_fn);
- }
-
- // Gets a VectorSpan with the features for given click position.
- bool Get(int click_pos, VectorSpan<float>* features,
- VectorSpan<Token>* output_tokens);
-
- // Turns on a compatibility mode, which re-maps the extracted features to the
- // v0 feature format (where the dense features were at the end).
- // WARNING: Internally v0_feature_storage_ is used as a backing buffer for
- // VectorSpan<float>, so the output of Extract is valid only until the next
- // call or destruction of the current CachedFeatures object.
- // TODO(zilka): Remove when we'll have retrained models.
- void SetV0FeatureMode(int chargram_embedding_size) {
- remap_v0_feature_vector_ = true;
- remap_v0_chargram_embedding_size_ = chargram_embedding_size;
- v0_feature_storage_.resize(feature_vector_size_ * (context_size_ * 2 + 1));
- }
-
- protected:
- // Extracts features for all tokens and stores them for later retrieval.
- void Extract(const std::vector<std::vector<int>>& sparse_features,
- const std::vector<std::vector<float>>& dense_features,
- const std::function<bool(const std::vector<int>&,
- const std::vector<float>&, float*)>&
- feature_vector_fn);
-
- // Remaps extracted features to V0 feature format. The mapping is using
- // the v0_feature_storage_ as the backing storage for the mapped features.
- // For each token the features consist of:
- // - chargram embeddings
- // - dense features
- // They are concatenated together as [chargram embeddings; dense features]
- // for each token independently.
- // The V0 features require that the chargram embeddings for tokens are
- // concatenated first together, and at the end, the dense features for the
- // tokens are concatenated to it.
- void RemapV0FeatureVector(VectorSpan<float>* features);
-
- private:
- const VectorSpan<Token> tokens_;
- const int context_size_;
- const int feature_vector_size_;
- bool remap_v0_feature_vector_;
- int remap_v0_chargram_embedding_size_;
-
- std::vector<float> features_;
- std::vector<float> v0_feature_storage_;
-};
-
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
diff --git a/smartselect/cached-features_test.cc b/smartselect/cached-features_test.cc
deleted file mode 100644
index b456816..0000000
--- a/smartselect/cached-features_test.cc
+++ /dev/null
@@ -1,149 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "smartselect/cached-features.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier {
-namespace {
-
-class TestingCachedFeatures : public CachedFeatures {
- public:
- using CachedFeatures::CachedFeatures;
- using CachedFeatures::RemapV0FeatureVector;
-};
-
-TEST(CachedFeaturesTest, Simple) {
- std::vector<Token> tokens;
- tokens.push_back(Token());
- tokens.push_back(Token());
- tokens.push_back(Token("Hello", 0, 1));
- tokens.push_back(Token("World", 1, 2));
- tokens.push_back(Token("today!", 2, 3));
- tokens.push_back(Token());
- tokens.push_back(Token());
-
- std::vector<std::vector<int>> sparse_features(tokens.size());
- for (int i = 0; i < sparse_features.size(); ++i) {
- sparse_features[i].push_back(i);
- }
- std::vector<std::vector<float>> dense_features(tokens.size());
- for (int i = 0; i < dense_features.size(); ++i) {
- dense_features[i].push_back(-i);
- }
-
- TestingCachedFeatures feature_extractor(
- tokens, /*context_size=*/2, sparse_features, dense_features,
- [](const std::vector<int>& sparse_features,
- const std::vector<float>& dense_features, float* features) {
- features[0] = sparse_features[0];
- features[1] = sparse_features[0];
- features[2] = dense_features[0];
- features[3] = dense_features[0];
- features[4] = 123;
- return true;
- },
- 5);
-
- VectorSpan<float> features;
- VectorSpan<Token> output_tokens;
- EXPECT_TRUE(feature_extractor.Get(2, &features, &output_tokens));
- for (int i = 0; i < 5; i++) {
- EXPECT_EQ(features[i * 5 + 0], i) << "Feature " << i;
- EXPECT_EQ(features[i * 5 + 1], i) << "Feature " << i;
- EXPECT_EQ(features[i * 5 + 2], -i) << "Feature " << i;
- EXPECT_EQ(features[i * 5 + 3], -i) << "Feature " << i;
- EXPECT_EQ(features[i * 5 + 4], 123) << "Feature " << i;
- }
-}
-
-TEST(CachedFeaturesTest, InvalidInput) {
- std::vector<Token> tokens;
- tokens.push_back(Token());
- tokens.push_back(Token());
- tokens.push_back(Token("Hello", 0, 1));
- tokens.push_back(Token("World", 1, 2));
- tokens.push_back(Token("today!", 2, 3));
- tokens.push_back(Token());
- tokens.push_back(Token());
-
- std::vector<std::vector<int>> sparse_features(tokens.size());
- std::vector<std::vector<float>> dense_features(tokens.size());
-
- TestingCachedFeatures feature_extractor(
- tokens, /*context_size=*/2, sparse_features, dense_features,
- [](const std::vector<int>& sparse_features,
- const std::vector<float>& dense_features,
- float* features) { return true; },
- /*feature_vector_size=*/5);
-
- VectorSpan<float> features;
- VectorSpan<Token> output_tokens;
- EXPECT_FALSE(feature_extractor.Get(-1000, &features, &output_tokens));
- EXPECT_FALSE(feature_extractor.Get(-1, &features, &output_tokens));
- EXPECT_FALSE(feature_extractor.Get(0, &features, &output_tokens));
- EXPECT_TRUE(feature_extractor.Get(2, &features, &output_tokens));
- EXPECT_TRUE(feature_extractor.Get(4, &features, &output_tokens));
- EXPECT_FALSE(feature_extractor.Get(5, &features, &output_tokens));
- EXPECT_FALSE(feature_extractor.Get(500, &features, &output_tokens));
-}
-
-TEST(CachedFeaturesTest, RemapV0FeatureVector) {
- std::vector<Token> tokens;
- tokens.push_back(Token());
- tokens.push_back(Token());
- tokens.push_back(Token("Hello", 0, 1));
- tokens.push_back(Token("World", 1, 2));
- tokens.push_back(Token("today!", 2, 3));
- tokens.push_back(Token());
- tokens.push_back(Token());
-
- std::vector<std::vector<int>> sparse_features(tokens.size());
- std::vector<std::vector<float>> dense_features(tokens.size());
-
- TestingCachedFeatures feature_extractor(
- tokens, /*context_size=*/2, sparse_features, dense_features,
- [](const std::vector<int>& sparse_features,
- const std::vector<float>& dense_features,
- float* features) { return true; },
- /*feature_vector_size=*/5);
-
- std::vector<float> features_orig(5 * 5);
- for (int i = 0; i < features_orig.size(); i++) {
- features_orig[i] = i;
- }
- VectorSpan<float> features;
-
- feature_extractor.SetV0FeatureMode(0);
- features = VectorSpan<float>(features_orig);
- feature_extractor.RemapV0FeatureVector(&features);
- EXPECT_EQ(
- std::vector<float>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
- 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}),
- std::vector<float>(features.begin(), features.end()));
-
- feature_extractor.SetV0FeatureMode(2);
- features = VectorSpan<float>(features_orig);
- feature_extractor.RemapV0FeatureVector(&features);
- EXPECT_EQ(std::vector<float>({0, 1, 5, 6, 10, 11, 15, 16, 20, 21, 2, 3, 4,
- 7, 8, 9, 12, 13, 14, 17, 18, 19, 22, 23, 24}),
- std::vector<float>(features.begin(), features.end()));
-}
-
-} // namespace
-} // namespace libtextclassifier
diff --git a/smartselect/model-params.cc b/smartselect/model-params.cc
deleted file mode 100644
index 65c4f93..0000000
--- a/smartselect/model-params.cc
+++ /dev/null
@@ -1,108 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "smartselect/model-params.h"
-
-#include "common/memory_image/memory-image-reader.h"
-
-namespace libtextclassifier {
-
-using nlp_core::EmbeddingNetworkProto;
-using nlp_core::MemoryImageReader;
-
-ModelParams* ModelParamsBuilder(
- const void* start, uint64 num_bytes,
- std::shared_ptr<EmbeddingParams> external_embedding_params) {
- MemoryImageReader<EmbeddingNetworkProto> reader(start, num_bytes);
-
- ModelOptions model_options;
- auto model_options_extension_id = model_options_in_embedding_network_proto;
- if (reader.trimmed_proto().HasExtension(model_options_extension_id)) {
- model_options =
- reader.trimmed_proto().GetExtension(model_options_extension_id);
- }
-
- FeatureProcessorOptions feature_processor_options;
- auto feature_processor_extension_id =
- feature_processor_options_in_embedding_network_proto;
- if (reader.trimmed_proto().HasExtension(feature_processor_extension_id)) {
- feature_processor_options =
- reader.trimmed_proto().GetExtension(feature_processor_extension_id);
-
- // If no tokenization codepoint config is present, tokenize on space.
- // TODO(zilka): Remove the default config.
- if (feature_processor_options.tokenization_codepoint_config_size() == 0) {
- TokenizationCodepointRange* config;
- // New line character.
- config = feature_processor_options.add_tokenization_codepoint_config();
- config->set_start(10);
- config->set_end(11);
- config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
-
- // Space character.
- config = feature_processor_options.add_tokenization_codepoint_config();
- config->set_start(32);
- config->set_end(33);
- config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
- }
- } else {
- return nullptr;
- }
-
- SelectionModelOptions selection_options;
- auto selection_options_extension_id =
- selection_model_options_in_embedding_network_proto;
- if (reader.trimmed_proto().HasExtension(selection_options_extension_id)) {
- selection_options =
- reader.trimmed_proto().GetExtension(selection_options_extension_id);
-
- // For backward compatibility with the current models.
- if (!feature_processor_options.ignored_span_boundary_codepoints_size()) {
- *feature_processor_options.mutable_ignored_span_boundary_codepoints() =
- selection_options.deprecated_punctuation_to_strip();
- }
- } else {
- selection_options.set_enforce_symmetry(true);
- selection_options.set_symmetry_context_size(
- feature_processor_options.context_size() * 2);
- }
-
- SharingModelOptions sharing_options;
- auto sharing_options_extension_id =
- sharing_model_options_in_embedding_network_proto;
- if (reader.trimmed_proto().HasExtension(sharing_options_extension_id)) {
- sharing_options =
- reader.trimmed_proto().GetExtension(sharing_options_extension_id);
- } else {
- // Default values when SharingModelOptions is not present.
- sharing_options.set_always_accept_url_hint(true);
- sharing_options.set_always_accept_email_hint(true);
- }
-
- if (!model_options.use_shared_embeddings()) {
- std::shared_ptr<EmbeddingParams> embedding_params(new EmbeddingParams(
- start, num_bytes, feature_processor_options.context_size()));
- return new ModelParams(start, num_bytes, embedding_params,
- selection_options, sharing_options,
- feature_processor_options);
- } else {
- return new ModelParams(
- start, num_bytes, std::move(external_embedding_params),
- selection_options, sharing_options, feature_processor_options);
- }
-}
-
-} // namespace libtextclassifier
diff --git a/smartselect/model-params.h b/smartselect/model-params.h
deleted file mode 100644
index a0d39e6..0000000
--- a/smartselect/model-params.h
+++ /dev/null
@@ -1,152 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Model parameter loading.
-
-#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_
-#define LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_
-
-#include "common/embedding-network.h"
-#include "common/memory_image/embedding-network-params-from-image.h"
-#include "smartselect/text-classification-model.pb.h"
-
-namespace libtextclassifier {
-
-class EmbeddingParams : public nlp_core::EmbeddingNetworkParamsFromImage {
- public:
- EmbeddingParams(const void* start, uint64 num_bytes, int context_size)
- : EmbeddingNetworkParamsFromImage(start, num_bytes),
- context_size_(context_size) {}
-
- int embeddings_size() const override { return context_size_ * 2 + 1; }
-
- int embedding_num_features_size() const override {
- return context_size_ * 2 + 1;
- }
-
- int embedding_num_features(int i) const override { return 1; }
-
- int embeddings_num_rows(int i) const override {
- return EmbeddingNetworkParamsFromImage::embeddings_num_rows(0);
- };
-
- int embeddings_num_cols(int i) const override {
- return EmbeddingNetworkParamsFromImage::embeddings_num_cols(0);
- };
-
- const void* embeddings_weights(int i) const override {
- return EmbeddingNetworkParamsFromImage::embeddings_weights(0);
- };
-
- nlp_core::QuantizationType embeddings_quant_type(int i) const override {
- return EmbeddingNetworkParamsFromImage::embeddings_quant_type(0);
- }
-
- const nlp_core::float16* embeddings_quant_scales(int i) const override {
- return EmbeddingNetworkParamsFromImage::embeddings_quant_scales(0);
- }
-
- private:
- int context_size_;
-};
-
-// Loads and holds the parameters of the inference network.
-//
-// This class overrides a couple of methods of EmbeddingNetworkParamsFromImage
-// because we only have one embedding matrix for all positions of context,
-// whereas the original class would have a separate one for each.
-class ModelParams : public nlp_core::EmbeddingNetworkParamsFromImage {
- public:
- const FeatureProcessorOptions& GetFeatureProcessorOptions() const {
- return feature_processor_options_;
- }
-
- const SelectionModelOptions& GetSelectionModelOptions() const {
- return selection_options_;
- }
-
- const SharingModelOptions& GetSharingModelOptions() const {
- return sharing_options_;
- }
-
- std::shared_ptr<EmbeddingParams> GetEmbeddingParams() const {
- return embedding_params_;
- }
-
- protected:
- int embeddings_size() const override {
- return embedding_params_->embeddings_size();
- }
-
- int embedding_num_features_size() const override {
- return embedding_params_->embedding_num_features_size();
- }
-
- int embedding_num_features(int i) const override {
- return embedding_params_->embedding_num_features(i);
- }
-
- int embeddings_num_rows(int i) const override {
- return embedding_params_->embeddings_num_rows(i);
- };
-
- int embeddings_num_cols(int i) const override {
- return embedding_params_->embeddings_num_cols(i);
- };
-
- const void* embeddings_weights(int i) const override {
- return embedding_params_->embeddings_weights(i);
- };
-
- nlp_core::QuantizationType embeddings_quant_type(int i) const override {
- return embedding_params_->embeddings_quant_type(i);
- }
-
- const nlp_core::float16* embeddings_quant_scales(int i) const override {
- return embedding_params_->embeddings_quant_scales(i);
- }
-
- private:
- friend ModelParams* ModelParamsBuilder(
- const void* start, uint64 num_bytes,
- std::shared_ptr<EmbeddingParams> external_embedding_params);
-
- ModelParams(const void* start, uint64 num_bytes,
- std::shared_ptr<EmbeddingParams> embedding_params,
- const SelectionModelOptions& selection_options,
- const SharingModelOptions& sharing_options,
- const FeatureProcessorOptions& feature_processor_options)
- : EmbeddingNetworkParamsFromImage(start, num_bytes),
- selection_options_(selection_options),
- sharing_options_(sharing_options),
- feature_processor_options_(feature_processor_options),
- context_size_(feature_processor_options_.context_size()),
- embedding_params_(std::move(embedding_params)) {}
-
- SelectionModelOptions selection_options_;
- SharingModelOptions sharing_options_;
- FeatureProcessorOptions feature_processor_options_;
- int context_size_;
- std::shared_ptr<EmbeddingParams> embedding_params_;
-};
-
-ModelParams* ModelParamsBuilder(
- const void* start, uint64 num_bytes,
- std::shared_ptr<EmbeddingParams> external_embedding_params);
-
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_
diff --git a/smartselect/model-parser.cc b/smartselect/model-parser.cc
deleted file mode 100644
index 0cf05e3..0000000
--- a/smartselect/model-parser.cc
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "smartselect/model-parser.h"
-#include "util/base/endian.h"
-
-namespace libtextclassifier {
-namespace {
-
-// Small helper class for parsing the merged model format.
-// The merged model consists of interleaved <int32 data_size, char* data>
-// segments.
-class MergedModelParser {
- public:
- MergedModelParser(const void* addr, const int size)
- : addr_(reinterpret_cast<const char*>(addr)), size_(size), pos_(addr_) {}
-
- bool ReadBytesAndAdvance(int num_bytes, const char** result) {
- const char* read_addr = pos_;
- if (Advance(num_bytes)) {
- *result = read_addr;
- return true;
- } else {
- return false;
- }
- }
-
- bool ReadInt32AndAdvance(int* result) {
- const char* read_addr = pos_;
- if (Advance(sizeof(int))) {
- *result =
- LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(read_addr));
- return true;
- } else {
- return false;
- }
- }
-
- bool IsDone() { return pos_ == addr_ + size_; }
-
- private:
- bool Advance(int num_bytes) {
- pos_ += num_bytes;
- return pos_ <= addr_ + size_;
- }
-
- const char* addr_;
- const int size_;
- const char* pos_;
-};
-
-} // namespace
-
-bool ParseMergedModel(const void* addr, const int size,
- const char** selection_model, int* selection_model_length,
- const char** sharing_model, int* sharing_model_length) {
- MergedModelParser parser(addr, size);
-
- if (!parser.ReadInt32AndAdvance(selection_model_length)) {
- return false;
- }
-
- if (!parser.ReadBytesAndAdvance(*selection_model_length, selection_model)) {
- return false;
- }
-
- if (!parser.ReadInt32AndAdvance(sharing_model_length)) {
- return false;
- }
-
- if (!parser.ReadBytesAndAdvance(*sharing_model_length, sharing_model)) {
- return false;
- }
-
- return parser.IsDone();
-}
-
-} // namespace libtextclassifier
diff --git a/smartselect/model-parser.h b/smartselect/model-parser.h
deleted file mode 100644
index 801262f..0000000
--- a/smartselect/model-parser.h
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARSER_H_
-#define LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARSER_H_
-
-namespace libtextclassifier {
-
-// Parse a merged model image.
-bool ParseMergedModel(const void* addr, const int size,
- const char** selection_model, int* selection_model_length,
- const char** sharing_model, int* sharing_model_length);
-
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARSER_H_
diff --git a/smartselect/text-classification-model.cc b/smartselect/text-classification-model.cc
deleted file mode 100644
index e7ae09c..0000000
--- a/smartselect/text-classification-model.cc
+++ /dev/null
@@ -1,741 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "smartselect/text-classification-model.h"
-
-#include <cctype>
-#include <cmath>
-#include <iterator>
-#include <numeric>
-
-#include "common/embedding-network.h"
-#include "common/feature-extractor.h"
-#include "common/memory_image/embedding-network-params-from-image.h"
-#include "common/memory_image/memory-image-reader.h"
-#include "common/mmap.h"
-#include "common/softmax.h"
-#include "smartselect/model-parser.h"
-#include "smartselect/text-classification-model.pb.h"
-#include "util/base/logging.h"
-#include "util/utf8/unicodetext.h"
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
-#include "unicode/regex.h"
-#include "unicode/uchar.h"
-#endif
-
-namespace libtextclassifier {
-
-using nlp_core::EmbeddingNetwork;
-using nlp_core::EmbeddingNetworkProto;
-using nlp_core::FeatureVector;
-using nlp_core::MemoryImageReader;
-using nlp_core::MmapFile;
-using nlp_core::MmapHandle;
-using nlp_core::ScopedMmap;
-
-namespace {
-
-int CountDigits(const std::string& str, CodepointSpan selection_indices) {
- int count = 0;
- int i = 0;
- const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
- for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
- if (i >= selection_indices.first && i < selection_indices.second &&
- isdigit(*it)) {
- ++count;
- }
- }
- return count;
-}
-
-std::string ExtractSelection(const std::string& context,
- CodepointSpan selection_indices) {
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
- auto selection_begin = context_unicode.begin();
- std::advance(selection_begin, selection_indices.first);
- auto selection_end = context_unicode.begin();
- std::advance(selection_end, selection_indices.second);
- return UnicodeText::UTF8Substring(selection_begin, selection_end);
-}
-
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
-bool MatchesRegex(const icu::RegexPattern* regex, const std::string& context) {
- const icu::UnicodeString unicode_context(context.c_str(), context.size(),
- "utf-8");
- UErrorCode status = U_ZERO_ERROR;
- std::unique_ptr<icu::RegexMatcher> matcher(
- regex->matcher(unicode_context, status));
- return matcher->matches(0 /* start */, status);
-}
-#endif
-
-} // namespace
-
-TextClassificationModel::TextClassificationModel(const std::string& path)
- : mmap_(new nlp_core::ScopedMmap(path)) {
- InitFromMmap();
-}
-
-TextClassificationModel::TextClassificationModel(int fd)
- : mmap_(new nlp_core::ScopedMmap(fd)) {
- InitFromMmap();
-}
-
-TextClassificationModel::TextClassificationModel(int fd, int offset, int size)
- : mmap_(new nlp_core::ScopedMmap(fd, offset, size)) {
- InitFromMmap();
-}
-
-TextClassificationModel::TextClassificationModel(const void* addr, int size) {
- initialized_ = LoadModels(addr, size);
- if (!initialized_) {
- TC_LOG(ERROR) << "Failed to load models";
- return;
- }
-}
-
-void TextClassificationModel::InitFromMmap() {
- if (!mmap_->handle().ok()) {
- return;
- }
-
- initialized_ =
- LoadModels(mmap_->handle().start(), mmap_->handle().num_bytes());
- if (!initialized_) {
- TC_LOG(ERROR) << "Failed to load models";
- return;
- }
-}
-
-namespace {
-
-// Converts sparse features vector to nlp_core::FeatureVector.
-void SparseFeaturesToFeatureVector(
- const std::vector<int> sparse_features,
- const nlp_core::NumericFeatureType& feature_type,
- nlp_core::FeatureVector* result) {
- for (int feature_id : sparse_features) {
- const int64 feature_value =
- nlp_core::FloatFeatureValue(feature_id, 1.0 / sparse_features.size())
- .discrete_value;
- result->add(const_cast<nlp_core::NumericFeatureType*>(&feature_type),
- feature_value);
- }
-}
-
-// Returns a function that can be used for mapping sparse and dense features
-// to a float feature vector.
-// NOTE: The network object needs to be available at the time when the returned
-// function object is used.
-FeatureVectorFn CreateFeatureVectorFn(const EmbeddingNetwork& network,
- int sparse_embedding_size) {
- const nlp_core::NumericFeatureType feature_type("chargram_continuous", 0);
- return [&network, sparse_embedding_size, feature_type](
- const std::vector<int>& sparse_features,
- const std::vector<float>& dense_features, float* embedding) {
- nlp_core::FeatureVector feature_vector;
- SparseFeaturesToFeatureVector(sparse_features, feature_type,
- &feature_vector);
-
- if (network.GetEmbedding(feature_vector, 0, embedding)) {
- for (int i = 0; i < dense_features.size(); i++) {
- embedding[sparse_embedding_size + i] = dense_features[i];
- }
- return true;
- } else {
- return false;
- }
- };
-}
-
-} // namespace
-
-void TextClassificationModel::InitializeSharingRegexPatterns(
- const std::vector<SharingModelOptions::RegexPattern>& patterns) {
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
- // Initialize pattern recognizers.
- for (const auto& regex_pattern : patterns) {
- UErrorCode status = U_ZERO_ERROR;
- std::unique_ptr<icu::RegexPattern> compiled_pattern(
- icu::RegexPattern::compile(
- icu::UnicodeString(regex_pattern.pattern().c_str(),
- regex_pattern.pattern().size(), "utf-8"),
- 0 /* flags */, status));
- if (U_FAILURE(status)) {
- TC_LOG(WARNING) << "Failed to load pattern" << regex_pattern.pattern();
- } else {
- regex_patterns_.push_back(
- {regex_pattern.collection_name(), std::move(compiled_pattern)});
- }
- }
-#else
- if (!patterns.empty()) {
- TC_LOG(WARNING) << "ICU not supported regexp matchers ignored.";
- }
-#endif
-}
-
-bool TextClassificationModel::LoadModels(const void* addr, int size) {
- const char *selection_model, *sharing_model;
- int selection_model_length, sharing_model_length;
- if (!ParseMergedModel(addr, size, &selection_model, &selection_model_length,
- &sharing_model, &sharing_model_length)) {
- TC_LOG(ERROR) << "Couldn't parse the model.";
- return false;
- }
-
- selection_params_.reset(
- ModelParamsBuilder(selection_model, selection_model_length, nullptr));
- if (!selection_params_.get()) {
- return false;
- }
- selection_options_ = selection_params_->GetSelectionModelOptions();
- selection_network_.reset(new EmbeddingNetwork(selection_params_.get()));
- selection_feature_processor_.reset(
- new FeatureProcessor(selection_params_->GetFeatureProcessorOptions()));
- selection_feature_fn_ = CreateFeatureVectorFn(
- *selection_network_, selection_network_->EmbeddingSize(0));
-
- sharing_params_.reset(
- ModelParamsBuilder(sharing_model, sharing_model_length,
- selection_params_->GetEmbeddingParams()));
- if (!sharing_params_.get()) {
- return false;
- }
- sharing_options_ = selection_params_->GetSharingModelOptions();
- sharing_network_.reset(new EmbeddingNetwork(sharing_params_.get()));
- sharing_feature_processor_.reset(
- new FeatureProcessor(sharing_params_->GetFeatureProcessorOptions()));
- sharing_feature_fn_ = CreateFeatureVectorFn(
- *sharing_network_, sharing_network_->EmbeddingSize(0));
-
- InitializeSharingRegexPatterns(std::vector<SharingModelOptions::RegexPattern>(
- sharing_options_.regex_pattern().begin(),
- sharing_options_.regex_pattern().end()));
-
- return true;
-}
-
-bool ReadSelectionModelOptions(int fd, ModelOptions* model_options) {
- ScopedMmap mmap = ScopedMmap(fd);
- if (!mmap.handle().ok()) {
- TC_LOG(ERROR) << "Can't mmap.";
- return false;
- }
-
- const char *selection_model, *sharing_model;
- int selection_model_length, sharing_model_length;
- if (!ParseMergedModel(mmap.handle().start(), mmap.handle().num_bytes(),
- &selection_model, &selection_model_length,
- &sharing_model, &sharing_model_length)) {
- TC_LOG(ERROR) << "Couldn't parse merged model.";
- return false;
- }
-
- MemoryImageReader<EmbeddingNetworkProto> reader(selection_model,
- selection_model_length);
-
- auto model_options_extension_id = model_options_in_embedding_network_proto;
- if (reader.trimmed_proto().HasExtension(model_options_extension_id)) {
- *model_options =
- reader.trimmed_proto().GetExtension(model_options_extension_id);
- return true;
- } else {
- return false;
- }
-}
-
-EmbeddingNetwork::Vector TextClassificationModel::InferInternal(
- const std::string& context, CodepointSpan span,
- const FeatureProcessor& feature_processor, const EmbeddingNetwork& network,
- const FeatureVectorFn& feature_vector_fn,
- std::vector<CodepointSpan>* selection_label_spans) const {
- std::vector<Token> tokens;
- int click_pos;
- std::unique_ptr<CachedFeatures> cached_features;
- const int embedding_size = network.EmbeddingSize(0);
- if (!feature_processor.ExtractFeatures(
- context, span, /*relative_click_span=*/{0, 0},
- CreateFeatureVectorFn(network, embedding_size),
- embedding_size + feature_processor.DenseFeaturesCount(), &tokens,
- &click_pos, &cached_features)) {
- TC_VLOG(1) << "Could not extract features.";
- return {};
- }
-
- VectorSpan<float> features;
- VectorSpan<Token> output_tokens;
- if (!cached_features->Get(click_pos, &features, &output_tokens)) {
- TC_VLOG(1) << "Could not extract features.";
- return {};
- }
-
- if (selection_label_spans != nullptr) {
- if (!feature_processor.SelectionLabelSpans(output_tokens,
- selection_label_spans)) {
- TC_LOG(ERROR) << "Could not get spans for selection labels.";
- return {};
- }
- }
-
- std::vector<float> scores;
- network.ComputeLogits(features, &scores);
- return scores;
-}
-
-namespace {
-
-// Returns true if given codepoint is contained in the given span in context.
-bool IsCodepointInSpan(const char32 codepoint, const std::string& context,
- const CodepointSpan span) {
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
-
- auto begin_it = context_unicode.begin();
- std::advance(begin_it, span.first);
- auto end_it = context_unicode.begin();
- std::advance(end_it, span.second);
-
- return std::find(begin_it, end_it, codepoint) != end_it;
-}
-
-// Returns the first codepoint of the span.
-char32 FirstSpanCodepoint(const std::string& context,
- const CodepointSpan span) {
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
-
- auto it = context_unicode.begin();
- std::advance(it, span.first);
- return *it;
-}
-
-// Returns the last codepoint of the span.
-char32 LastSpanCodepoint(const std::string& context, const CodepointSpan span) {
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
-
- auto it = context_unicode.begin();
- std::advance(it, span.second - 1);
- return *it;
-}
-
-} // namespace
-
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
-
-namespace {
-
-bool IsOpenBracket(const char32 codepoint) {
- return u_getIntPropertyValue(codepoint, UCHAR_BIDI_PAIRED_BRACKET_TYPE) ==
- U_BPT_OPEN;
-}
-
-bool IsClosingBracket(const char32 codepoint) {
- return u_getIntPropertyValue(codepoint, UCHAR_BIDI_PAIRED_BRACKET_TYPE) ==
- U_BPT_CLOSE;
-}
-
-} // namespace
-
-// If the first or the last codepoint of the given span is a bracket, the
-// bracket is stripped if the span does not contain its corresponding paired
-// version.
-CodepointSpan StripUnpairedBrackets(const std::string& context,
- CodepointSpan span) {
- if (context.empty()) {
- return span;
- }
-
- const char32 begin_char = FirstSpanCodepoint(context, span);
-
- const char32 paired_begin_char = u_getBidiPairedBracket(begin_char);
- if (paired_begin_char != begin_char) {
- if (!IsOpenBracket(begin_char) ||
- !IsCodepointInSpan(paired_begin_char, context, span)) {
- ++span.first;
- }
- }
-
- if (span.first == span.second) {
- return span;
- }
-
- const char32 end_char = LastSpanCodepoint(context, span);
- const char32 paired_end_char = u_getBidiPairedBracket(end_char);
- if (paired_end_char != end_char) {
- if (!IsClosingBracket(end_char) ||
- !IsCodepointInSpan(paired_end_char, context, span)) {
- --span.second;
- }
- }
-
- // Should not happen, but let's make sure.
- if (span.first > span.second) {
- TC_LOG(WARNING) << "Inverse indices result: " << span.first << ", "
- << span.second;
- span.second = span.first;
- }
-
- return span;
-}
-#endif
-
-CodepointSpan TextClassificationModel::SuggestSelection(
- const std::string& context, CodepointSpan click_indices) const {
- if (!initialized_) {
- TC_LOG(ERROR) << "Not initialized";
- return click_indices;
- }
-
- const int context_codepoint_size =
- UTF8ToUnicodeText(context, /*do_copy=*/false).size();
-
- if (click_indices.first < 0 || click_indices.second < 0 ||
- click_indices.first >= context_codepoint_size ||
- click_indices.second > context_codepoint_size ||
- click_indices.first >= click_indices.second) {
- TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices: "
- << click_indices.first << " " << click_indices.second;
- return click_indices;
- }
-
- CodepointSpan result;
- if (selection_options_.enforce_symmetry()) {
- result = SuggestSelectionSymmetrical(context, click_indices);
- } else {
- float score;
- std::tie(result, score) = SuggestSelectionInternal(context, click_indices);
- }
-
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
- if (selection_options_.strip_unpaired_brackets()) {
- const CodepointSpan stripped_result =
- StripUnpairedBrackets(context, result);
- if (stripped_result.first != stripped_result.second) {
- result = stripped_result;
- }
- }
-#endif
-
- return result;
-}
-
-namespace {
-
-int BestPrediction(const std::vector<float>& scores) {
- if (!scores.empty()) {
- const int prediction =
- std::max_element(scores.begin(), scores.end()) - scores.begin();
- return prediction;
- } else {
- return kInvalidLabel;
- }
-}
-
-std::pair<CodepointSpan, float> BestSelectionSpan(
- CodepointSpan original_click_indices, const std::vector<float>& scores,
- const std::vector<CodepointSpan>& selection_label_spans) {
- const int prediction = BestPrediction(scores);
- if (prediction != kInvalidLabel) {
- std::pair<CodepointIndex, CodepointIndex> selection =
- selection_label_spans[prediction];
-
- if (selection.first == kInvalidIndex || selection.second == kInvalidIndex) {
- TC_VLOG(1) << "Invalid indices predicted, returning input: " << prediction
- << " " << selection.first << " " << selection.second;
- return {original_click_indices, -1.0};
- }
-
- return {{selection.first, selection.second}, scores[prediction]};
- } else {
- TC_LOG(ERROR) << "Returning default selection: scores.size() = "
- << scores.size();
- return {original_click_indices, -1.0};
- }
-}
-
-} // namespace
-
-std::pair<CodepointSpan, float>
-TextClassificationModel::SuggestSelectionInternal(
- const std::string& context, CodepointSpan click_indices) const {
- if (!initialized_) {
- TC_LOG(ERROR) << "Not initialized";
- return {click_indices, -1.0};
- }
-
- std::vector<CodepointSpan> selection_label_spans;
- EmbeddingNetwork::Vector scores = InferInternal(
- context, click_indices, *selection_feature_processor_,
- *selection_network_, selection_feature_fn_, &selection_label_spans);
- scores = nlp_core::ComputeSoftmax(scores);
-
- return BestSelectionSpan(click_indices, scores, selection_label_spans);
-}
-
-// Implements a greedy-search-like algorithm for making selections symmetric.
-//
-// Steps:
-// 1. Get a set of selection proposals from places around the clicked word.
-// 2. For each proposal (going from highest-scoring), check if the tokens that
-// the proposal selects are still free, in which case it claims them, if a
-// proposal that contains the clicked token is found, it is returned as the
-// suggestion.
-//
-// This algorithm should ensure that if a selection is proposed, it does not
-// matter which word of it was tapped - all of them will lead to the same
-// selection.
-CodepointSpan TextClassificationModel::SuggestSelectionSymmetrical(
- const std::string& context, CodepointSpan click_indices) const {
- const int symmetry_context_size = selection_options_.symmetry_context_size();
- std::vector<CodepointSpan> chunks = Chunk(
- context, click_indices, {symmetry_context_size, symmetry_context_size});
- for (const CodepointSpan& chunk : chunks) {
- // If chunk and click indices have an overlap, return the chunk.
- if (!(click_indices.first >= chunk.second ||
- click_indices.second <= chunk.first)) {
- return chunk;
- }
- }
-
- return click_indices;
-}
-
-std::vector<std::pair<std::string, float>>
-TextClassificationModel::ClassifyText(const std::string& context,
- CodepointSpan selection_indices,
- int hint_flags) const {
- if (!initialized_) {
- TC_LOG(ERROR) << "Not initialized";
- return {};
- }
-
- if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
- TC_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
- << std::get<0>(selection_indices) << " "
- << std::get<1>(selection_indices);
- return {};
- }
-
- if (hint_flags & SELECTION_IS_URL &&
- sharing_options_.always_accept_url_hint()) {
- return {{kUrlHintCollection, 1.0}};
- }
-
- if (hint_flags & SELECTION_IS_EMAIL &&
- sharing_options_.always_accept_email_hint()) {
- return {{kEmailHintCollection, 1.0}};
- }
-
- // Check whether any of the regular expressions match.
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
- const std::string selection_text =
- ExtractSelection(context, selection_indices);
- for (const CompiledRegexPattern& regex_pattern : regex_patterns_) {
- if (MatchesRegex(regex_pattern.pattern.get(), selection_text)) {
- return {{regex_pattern.collection_name, 1.0}};
- }
- }
-#endif
-
- EmbeddingNetwork::Vector scores =
- InferInternal(context, selection_indices, *sharing_feature_processor_,
- *sharing_network_, sharing_feature_fn_, nullptr);
- if (scores.empty() ||
- scores.size() != sharing_feature_processor_->NumCollections()) {
- TC_VLOG(1) << "Using default class: scores.size() = " << scores.size();
- return {};
- }
-
- scores = nlp_core::ComputeSoftmax(scores);
-
- std::vector<std::pair<std::string, float>> result(scores.size());
- for (int i = 0; i < scores.size(); i++) {
- result[i] = {sharing_feature_processor_->LabelToCollection(i), scores[i]};
- }
- std::sort(result.begin(), result.end(),
- [](const std::pair<std::string, float>& a,
- const std::pair<std::string, float>& b) {
- return a.second > b.second;
- });
-
- // Phone class sanity check.
- if (result.begin()->first == kPhoneCollection) {
- const int digit_count = CountDigits(context, selection_indices);
- if (digit_count < sharing_options_.phone_min_num_digits() ||
- digit_count > sharing_options_.phone_max_num_digits()) {
- return {{kOtherCollection, 1.0}};
- }
- }
-
- return result;
-}
-
-std::vector<CodepointSpan> TextClassificationModel::Chunk(
- const std::string& context, CodepointSpan click_span,
- TokenSpan relative_click_span) const {
- std::unique_ptr<CachedFeatures> cached_features;
- std::vector<Token> tokens;
- int click_index;
- int embedding_size = selection_network_->EmbeddingSize(0);
- if (!selection_feature_processor_->ExtractFeatures(
- context, click_span, relative_click_span, selection_feature_fn_,
- embedding_size + selection_feature_processor_->DenseFeaturesCount(),
- &tokens, &click_index, &cached_features)) {
- TC_VLOG(1) << "Couldn't ExtractFeatures.";
- return {};
- }
-
- int first_token;
- int last_token;
- if (relative_click_span.first == kInvalidIndex ||
- relative_click_span.second == kInvalidIndex) {
- first_token = 0;
- last_token = tokens.size();
- } else {
- first_token = click_index - relative_click_span.first;
- last_token = click_index + relative_click_span.second + 1;
- }
-
- struct SelectionProposal {
- int label;
- int token_index;
- CodepointSpan span;
- float score;
- };
-
- // Scan in the symmetry context for selection span proposals.
- std::vector<SelectionProposal> proposals;
- for (int token_index = first_token; token_index < last_token; ++token_index) {
- if (token_index < 0 || token_index >= tokens.size() ||
- tokens[token_index].is_padding) {
- continue;
- }
-
- float score;
- VectorSpan<float> features;
- VectorSpan<Token> output_tokens;
- std::vector<CodepointSpan> selection_label_spans;
- CodepointSpan span;
- if (cached_features->Get(token_index, &features, &output_tokens) &&
- selection_feature_processor_->SelectionLabelSpans(
- output_tokens, &selection_label_spans)) {
- // Add an implicit proposal for each token to be by itself. Every
- // token should be now represented in the results.
- proposals.push_back(
- SelectionProposal{0, token_index, selection_label_spans[0], 0.0});
-
- std::vector<float> scores;
- selection_network_->ComputeLogits(features, &scores);
-
- scores = nlp_core::ComputeSoftmax(scores);
- std::tie(span, score) = BestSelectionSpan({kInvalidIndex, kInvalidIndex},
- scores, selection_label_spans);
- if (span.first != kInvalidIndex && span.second != kInvalidIndex &&
- score >= 0) {
- const int prediction = BestPrediction(scores);
- proposals.push_back(
- SelectionProposal{prediction, token_index, span, score});
- }
- } else {
- // Add an implicit proposal for each token to be by itself. Every token
- // should be now represented in the results.
- proposals.push_back(SelectionProposal{
- 0,
- token_index,
- {tokens[token_index].start, tokens[token_index].end},
- 0.0});
- }
- }
-
- // Sort selection span proposals by their respective probabilities.
- std::sort(proposals.begin(), proposals.end(),
- [](const SelectionProposal& a, const SelectionProposal& b) {
- return a.score > b.score;
- });
-
- // Go from the highest-scoring proposal and claim tokens. Tokens are marked as
- // claimed by the higher-scoring selection proposals, so that the
- // lower-scoring ones cannot use them. Returns the selection proposal if it
- // contains the clicked token.
- std::vector<CodepointSpan> result;
- std::vector<bool> token_used(tokens.size(), false);
- for (const SelectionProposal& proposal : proposals) {
- const int predicted_label = proposal.label;
- TokenSpan relative_span;
- if (!selection_feature_processor_->LabelToTokenSpan(predicted_label,
- &relative_span)) {
- continue;
- }
- TokenSpan span;
- span.first = proposal.token_index - relative_span.first;
- span.second = proposal.token_index + relative_span.second + 1;
-
- if (span.first != kInvalidIndex && span.second != kInvalidIndex) {
- bool feasible = true;
- for (int i = span.first; i < span.second; i++) {
- if (token_used[i]) {
- feasible = false;
- break;
- }
- }
-
- if (feasible) {
- result.push_back(proposal.span);
- for (int i = span.first; i < span.second; i++) {
- token_used[i] = true;
- }
- }
- }
- }
-
- std::sort(result.begin(), result.end(),
- [](const CodepointSpan& a, const CodepointSpan& b) {
- return a.first < b.first;
- });
-
- return result;
-}
-
-std::vector<TextClassificationModel::AnnotatedSpan>
-TextClassificationModel::Annotate(const std::string& context) const {
- std::vector<CodepointSpan> chunks;
- const UnicodeText context_unicode = UTF8ToUnicodeText(context,
- /*do_copy=*/false);
- for (const UnicodeTextRange& line :
- selection_feature_processor_->SplitContext(context_unicode)) {
- const std::vector<CodepointSpan> local_chunks =
- Chunk(UnicodeText::UTF8Substring(line.first, line.second),
- /*click_span=*/{kInvalidIndex, kInvalidIndex},
- /*relative_click_span=*/{kInvalidIndex, kInvalidIndex});
- const int offset = std::distance(context_unicode.begin(), line.first);
- for (CodepointSpan chunk : local_chunks) {
- chunks.push_back({chunk.first + offset, chunk.second + offset});
- }
- }
-
- std::vector<TextClassificationModel::AnnotatedSpan> result;
- for (const CodepointSpan& chunk : chunks) {
- result.emplace_back();
- result.back().span = chunk;
- result.back().classification = ClassifyText(context, chunk);
- }
- return result;
-}
-
-} // namespace libtextclassifier
diff --git a/smartselect/text-classification-model.h b/smartselect/text-classification-model.h
deleted file mode 100644
index d0df193..0000000
--- a/smartselect/text-classification-model.h
+++ /dev/null
@@ -1,196 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Inference code for the feed-forward text classification models.
-
-#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
-#define LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
-
-#include <memory>
-#include <set>
-#include <string>
-
-#include "common/embedding-network.h"
-#include "common/feature-extractor.h"
-#include "common/memory_image/embedding-network-params-from-image.h"
-#include "common/mmap.h"
-#include "smartselect/feature-processor.h"
-#include "smartselect/model-params.h"
-#include "smartselect/text-classification-model.pb.h"
-#include "smartselect/types.h"
-
-namespace libtextclassifier {
-
-// SmartSelection/Sharing feed-forward model.
-class TextClassificationModel {
- public:
- // Represents a result of Annotate call.
- struct AnnotatedSpan {
- // Unicode codepoint indices in the input string.
- CodepointSpan span = {kInvalidIndex, kInvalidIndex};
-
- // Classification result for the span.
- std::vector<std::pair<std::string, float>> classification;
- };
-
- // Loads TextClassificationModel from given file given by an int
- // file descriptor.
- // Offset is byte a position in the file to the beginning of the model data.
- TextClassificationModel(int fd, int offset, int size);
-
- // Same as above but the whole file is mapped and it is assumed the model
- // starts at offset 0.
- explicit TextClassificationModel(int fd);
-
- // Loads TextClassificationModel from given file.
- explicit TextClassificationModel(const std::string& path);
-
- // Loads TextClassificationModel from given location in memory.
- TextClassificationModel(const void* addr, int size);
-
- // Returns true if the model is ready for use.
- bool IsInitialized() { return initialized_; }
-
- // Bit flags for the input selection.
- enum SelectionInputFlags { SELECTION_IS_URL = 0x1, SELECTION_IS_EMAIL = 0x2 };
-
- // Runs inference for given a context and current selection (i.e. index
- // of the first and one past last selected characters (utf8 codepoint
- // offsets)). Returns the indices (utf8 codepoint offsets) of the selection
- // beginning character and one past selection end character.
- // Returns the original click_indices if an error occurs.
- // NOTE: The selection indices are passed in and returned in terms of
- // UTF8 codepoints (not bytes).
- // Requires that the model is a smart selection model.
- CodepointSpan SuggestSelection(const std::string& context,
- CodepointSpan click_indices) const;
-
- // Classifies the selected text given the context string.
- // Requires that the model is a smart sharing model.
- // Returns an empty result if an error occurs.
- std::vector<std::pair<std::string, float>> ClassifyText(
- const std::string& context, CodepointSpan click_indices,
- int input_flags = 0) const;
-
- // Annotates given input text. The annotations should cover the whole input
- // context except for whitespaces, and are sorted by their position in the
- // context string.
- std::vector<AnnotatedSpan> Annotate(const std::string& context) const;
-
- protected:
- // Initializes the model from mmap_ file.
- void InitFromMmap();
-
- // Extracts chunks from the context. The extraction proceeds from the center
- // token determined by click_span and looks at relative_click_span tokens
- // left and right around the click position.
- // If relative_click_span == {kInvalidIndex, kInvalidIndex} then the whole
- // context is considered, regardless of the click_span.
- // Returns the chunks sorted by their position in the context string.
- std::vector<CodepointSpan> Chunk(const std::string& context,
- CodepointSpan click_span,
- TokenSpan relative_click_span) const;
-
- // During evaluation we need access to the feature processor.
- FeatureProcessor* SelectionFeatureProcessor() const {
- return selection_feature_processor_.get();
- }
-
- void InitializeSharingRegexPatterns(
- const std::vector<SharingModelOptions::RegexPattern>& patterns);
-
- // Collection name when url hint is accepted.
- const std::string kUrlHintCollection = "url";
-
- // Collection name when email hint is accepted.
- const std::string kEmailHintCollection = "email";
-
- // Collection name for other.
- const std::string kOtherCollection = "other";
-
- // Collection name for phone.
- const std::string kPhoneCollection = "phone";
-
- SelectionModelOptions selection_options_;
- SharingModelOptions sharing_options_;
-
- private:
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
- struct CompiledRegexPattern {
- std::string collection_name;
- std::unique_ptr<icu::RegexPattern> pattern;
- };
-#endif
-
- bool LoadModels(const void* addr, int size);
-
- nlp_core::EmbeddingNetwork::Vector InferInternal(
- const std::string& context, CodepointSpan span,
- const FeatureProcessor& feature_processor,
- const nlp_core::EmbeddingNetwork& network,
- const FeatureVectorFn& feature_vector_fn,
- std::vector<CodepointSpan>* selection_label_spans) const;
-
- // Returns a selection suggestion with a score.
- std::pair<CodepointSpan, float> SuggestSelectionInternal(
- const std::string& context, CodepointSpan click_indices) const;
-
- // Returns a selection suggestion and makes sure it's symmetric. Internally
- // runs several times SuggestSelectionInternal.
- CodepointSpan SuggestSelectionSymmetrical(const std::string& context,
- CodepointSpan click_indices) const;
-
- bool initialized_ = false;
- std::unique_ptr<nlp_core::ScopedMmap> mmap_;
- std::unique_ptr<ModelParams> selection_params_;
- std::unique_ptr<FeatureProcessor> selection_feature_processor_;
- std::unique_ptr<nlp_core::EmbeddingNetwork> selection_network_;
- FeatureVectorFn selection_feature_fn_;
- std::unique_ptr<FeatureProcessor> sharing_feature_processor_;
- std::unique_ptr<ModelParams> sharing_params_;
- std::unique_ptr<nlp_core::EmbeddingNetwork> sharing_network_;
- FeatureVectorFn sharing_feature_fn_;
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
- std::vector<CompiledRegexPattern> regex_patterns_;
-#endif
-};
-
-// If the first or the last codepoint of the given span is a bracket, the
-// bracket is stripped if the span does not contain its corresponding paired
-// version.
-CodepointSpan StripUnpairedBrackets(const std::string& context,
- CodepointSpan span);
-
-// Parses the merged image given as a file descriptor, and reads
-// the ModelOptions proto from the selection model.
-bool ReadSelectionModelOptions(int fd, ModelOptions* model_options);
-
-// Pretty-printing function for TextClassificationModel::AnnotatedSpan.
-inline std::ostream& operator<<(
- std::ostream& os, const TextClassificationModel::AnnotatedSpan& span) {
- std::string best_class;
- float best_score = -1;
- if (!span.classification.empty()) {
- best_class = span.classification[0].first;
- best_score = span.classification[0].second;
- }
- return os << "Span(" << span.span.first << ", " << span.span.second << ", "
- << best_class << ", " << best_score << ")";
-}
-
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto
deleted file mode 100644
index 315e849..0000000
--- a/smartselect/text-classification-model.proto
+++ /dev/null
@@ -1,234 +0,0 @@
-// Copyright (C) 2017 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Text classification model configuration.
-
-syntax = "proto2";
-option optimize_for = LITE_RUNTIME;
-
-import "external/libtextclassifier/common/embedding-network.proto";
-import "external/libtextclassifier/smartselect/tokenizer.proto";
-
-package libtextclassifier;
-
-// Generic options of a model, non-specific to selection or sharing.
-message ModelOptions {
- // If true, will use embeddings from a different model. This is mainly useful
- // for the Sharing model using the embeddings from the Selection model.
- optional bool use_shared_embeddings = 1;
-
- // Language of the model.
- optional string language = 2;
-
- // Version of the model.
- optional int32 version = 3;
-}
-
-message SelectionModelOptions {
- // A list of Unicode codepoints to strip from predicted selections.
- repeated int32 deprecated_punctuation_to_strip = 1;
-
- // Enforce symmetrical selections.
- optional bool enforce_symmetry = 3;
-
- // Number of inferences made around the click position (to one side), for
- // enforcing symmetry.
- optional int32 symmetry_context_size = 4;
-
- // If true, before the selection is returned, the unpaired brackets contained
- // in the predicted selection are stripped from the both selection ends.
- // The bracket codepoints are defined in the Unicode standard:
- // http://www.unicode.org/Public/UNIDATA/BidiBrackets.txt
- optional bool strip_unpaired_brackets = 5 [default = true];
-
- reserved 2;
-}
-
-message SharingModelOptions {
- // If true, will always return "url" when the url hint is passed in.
- optional bool always_accept_url_hint = 1;
-
- // If true, will always return "email" when the e-mail hint is passed in.
- optional bool always_accept_email_hint = 2;
-
- // Limits for phone numbers.
- optional int32 phone_min_num_digits = 3 [default = 7];
- optional int32 phone_max_num_digits = 4 [default = 15];
-
- // List of regular expression matchers to check.
- message RegexPattern {
- // The name of the collection of a match.
- optional string collection_name = 1;
-
- // The pattern to check.
- optional string pattern = 2;
- }
- repeated RegexPattern regex_pattern = 5;
-}
-
-// Next ID: 41
-message FeatureProcessorOptions {
- // Number of buckets used for hashing charactergrams.
- optional int32 num_buckets = 1 [default = -1];
-
- // Context size defines the number of words to the left and to the right of
- // the selected word to be used as context. For example, if context size is
- // N, then we take N words to the left and N words to the right of the
- // selected word as its context.
- optional int32 context_size = 2 [default = -1];
-
- // Maximum number of words of the context to select in total.
- optional int32 max_selection_span = 3 [default = -1];
-
- // Orders of charactergrams to extract. E.g., 2 means character bigrams, 3
- // character trigrams etc.
- repeated int32 chargram_orders = 4;
-
- // Maximum length of a word, in codepoints.
- optional int32 max_word_length = 21 [default = 20];
-
- // If true, will use the unicode-aware functionality for extracting features.
- optional bool unicode_aware_features = 19 [default = false];
-
- // Whether to extract the token case feature.
- optional bool extract_case_feature = 5 [default = false];
-
- // Whether to extract the selection mask feature.
- optional bool extract_selection_mask_feature = 6 [default = false];
-
- // List of regexps to run over each token. For each regexp, if there is a
- // match, a dense feature of 1.0 is emitted. Otherwise -1.0 is used.
- repeated string regexp_feature = 22;
-
- // Whether to remap all digits to a single number.
- optional bool remap_digits = 20 [default = false];
-
- // Whether to lower-case each token before generating hashgrams.
- optional bool lowercase_tokens = 33;
-
- // If true, the selection classifier output will contain only the selections
- // that are feasible (e.g., those that are shorter than max_selection_span),
- // if false, the output will be a complete cross-product of possible
- // selections to the left and posible selections to the right, including the
- // infeasible ones.
- // NOTE: Exists mainly for compatibility with older models that were trained
- // with the non-reduced output space.
- optional bool selection_reduced_output_space = 8 [default = true];
-
- // Collection names.
- repeated string collections = 9;
-
- // An index of collection in collections to be used if a collection name can't
- // be mapped to an id.
- optional int32 default_collection = 10 [default = -1];
-
- // If true, will split the input by lines, and only use the line that contains
- // the clicked token.
- optional bool only_use_line_with_click = 13 [default = false];
-
- // If true, will split tokens that contain the selection boundary, at the
- // position of the boundary.
- // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
- optional bool split_tokens_on_selection_boundaries = 14 [default = false];
-
- // Codepoint ranges that determine how different codepoints are tokenized.
- // The ranges must not overlap.
- repeated TokenizationCodepointRange tokenization_codepoint_config = 15;
-
- // Method for selecting the center token.
- enum CenterTokenSelectionMethod {
- DEFAULT_CENTER_TOKEN_METHOD = 0; // Invalid option.
-
- // Use click indices to determine the center token.
- CENTER_TOKEN_FROM_CLICK = 1;
-
- // Use selection indices to get a token range, and select the middle of it
- // as the center token.
- CENTER_TOKEN_MIDDLE_OF_SELECTION = 2;
- }
- optional CenterTokenSelectionMethod center_token_selection_method = 16;
-
- // If true, span boundaries will be snapped to containing tokens and not
- // required to exactly match token boundaries.
- optional bool snap_label_span_boundaries_to_containing_tokens = 18;
-
- // Range of codepoints start - end, where end is exclusive.
- message CodepointRange {
- optional int32 start = 1;
- optional int32 end = 2;
- }
-
- // A set of codepoint ranges supported by the model.
- repeated CodepointRange supported_codepoint_ranges = 23;
-
- // A set of codepoint ranges to use in the mixed tokenization mode to identify
- // stretches of tokens to re-tokenize using the internal tokenizer.
- repeated CodepointRange internal_tokenizer_codepoint_ranges = 34;
-
- // Minimum ratio of supported codepoints in the input context. If the ratio
- // is lower than this, the feature computation will fail.
- optional float min_supported_codepoint_ratio = 24 [default = 0.0];
-
- // Used for versioning the format of features the model expects.
- // - feature_version == 0:
- // For each token the features consist of:
- // - chargram embeddings
- // - dense features
- // Chargram embeddings for tokens are concatenated first together,
- // and at the end, the dense features for the tokens are concatenated
- // to it. So the resulting feature vector has two regions.
- optional int32 feature_version = 25 [default = 0];
-
- // Controls the type of tokenization the model will use for the input text.
- enum TokenizationType {
- INVALID_TOKENIZATION_TYPE = 0;
-
- // Use the internal tokenizer for tokenization.
- INTERNAL_TOKENIZER = 1;
-
- // Use ICU for tokenization.
- ICU = 2;
-
- // First apply ICU tokenization. Then identify stretches of tokens
- // consisting only of codepoints in internal_tokenizer_codepoint_ranges
- // and re-tokenize them using the internal tokenizer.
- MIXED = 3;
- }
- optional TokenizationType tokenization_type = 30
- [default = INTERNAL_TOKENIZER];
- optional bool icu_preserve_whitespace_tokens = 31 [default = false];
-
- // List of codepoints that will be stripped from beginning and end of
- // predicted spans.
- repeated int32 ignored_span_boundary_codepoints = 36;
-
- reserved 7, 11, 12, 26, 27, 28, 29, 32, 35, 39, 40;
-
- // List of allowed charactergrams. The extracted charactergrams are filtered
- // using this list, and charactergrams that are not present are interpreted as
- // out-of-vocabulary.
- // If no allowed_chargrams are specified, all charactergrams are allowed.
- // The field is typed as bytes type to allow non-UTF8 chargrams.
- repeated bytes allowed_chargrams = 38;
-};
-
-extend nlp_core.EmbeddingNetworkProto {
- optional ModelOptions model_options_in_embedding_network_proto = 150063045;
- optional FeatureProcessorOptions
- feature_processor_options_in_embedding_network_proto = 146230910;
- optional SelectionModelOptions
- selection_model_options_in_embedding_network_proto = 148190899;
- optional SharingModelOptions
- sharing_model_options_in_embedding_network_proto = 151445439;
-}
diff --git a/smartselect/text-classification-model_test.cc b/smartselect/text-classification-model_test.cc
deleted file mode 100644
index 5550e53..0000000
--- a/smartselect/text-classification-model_test.cc
+++ /dev/null
@@ -1,440 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "smartselect/text-classification-model.h"
-
-#include <fcntl.h>
-#include <stdio.h>
-#include <fstream>
-#include <iostream>
-#include <memory>
-#include <string>
-
-#include "gtest/gtest.h"
-
-namespace libtextclassifier {
-namespace {
-
-std::string ReadFile(const std::string& file_name) {
- std::ifstream file_stream(file_name);
- return std::string(std::istreambuf_iterator<char>(file_stream), {});
-}
-
-std::string GetModelPath() {
- return TEST_DATA_DIR "smartselection.model";
-}
-
-std::string GetURLRegexPath() {
- return TEST_DATA_DIR "regex_url.txt";
-}
-
-std::string GetEmailRegexPath() {
- return TEST_DATA_DIR "regex_email.txt";
-}
-
-TEST(TextClassificationModelTest, StripUnpairedBrackets) {
- // Stripping brackets strip brackets from length 1 bracket only selections.
- EXPECT_EQ(StripUnpairedBrackets("call me at ) today", {11, 12}),
- std::make_pair(12, 12));
- EXPECT_EQ(StripUnpairedBrackets("call me at ( today", {11, 12}),
- std::make_pair(12, 12));
-}
-
-TEST(TextClassificationModelTest, ReadModelOptions) {
- const std::string model_path = GetModelPath();
- int fd = open(model_path.c_str(), O_RDONLY);
- ModelOptions model_options;
- ASSERT_TRUE(ReadSelectionModelOptions(fd, &model_options));
- close(fd);
-
- EXPECT_EQ("en", model_options.language());
- EXPECT_GT(model_options.version(), 0);
-}
-
-TEST(TextClassificationModelTest, SuggestSelection) {
- const std::string model_path = GetModelPath();
- int fd = open(model_path.c_str(), O_RDONLY);
- std::unique_ptr<TextClassificationModel> model(
- new TextClassificationModel(fd));
- close(fd);
-
- EXPECT_EQ(model->SuggestSelection(
- "this afternoon Barack Obama gave a speech at", {15, 21}),
- std::make_pair(15, 27));
-
- // Try passing whole string.
- // If more than 1 token is specified, we should return back what entered.
- EXPECT_EQ(model->SuggestSelection("350 Third Street, Cambridge", {0, 27}),
- std::make_pair(0, 27));
-
- // Single letter.
- EXPECT_EQ(std::make_pair(0, 1), model->SuggestSelection("a", {0, 1}));
-
- // Single word.
- EXPECT_EQ(std::make_pair(0, 4), model->SuggestSelection("asdf", {0, 4}));
-
- EXPECT_EQ(model->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
- std::make_pair(11, 23));
-
- // Unpaired bracket stripping.
- EXPECT_EQ(
- model->SuggestSelection("call me at (857) 225 3556 today", {11, 16}),
- std::make_pair(11, 25));
- EXPECT_EQ(model->SuggestSelection("call me at (857 225 3556 today", {11, 15}),
- std::make_pair(12, 24));
- EXPECT_EQ(model->SuggestSelection("call me at 857 225 3556) today", {11, 14}),
- std::make_pair(11, 23));
- EXPECT_EQ(
- model->SuggestSelection("call me at )857 225 3556( today", {11, 15}),
- std::make_pair(12, 24));
-
- // If the resulting selection would be empty, the original span is returned.
- EXPECT_EQ(model->SuggestSelection("call me at )( today", {11, 13}),
- std::make_pair(11, 13));
- EXPECT_EQ(model->SuggestSelection("call me at ( today", {11, 12}),
- std::make_pair(11, 12));
- EXPECT_EQ(model->SuggestSelection("call me at ) today", {11, 12}),
- std::make_pair(11, 12));
-}
-
-TEST(TextClassificationModelTest, SuggestSelectionsAreSymmetric) {
- const std::string model_path = GetModelPath();
- int fd = open(model_path.c_str(), O_RDONLY);
- std::unique_ptr<TextClassificationModel> model(
- new TextClassificationModel(fd));
- close(fd);
-
- EXPECT_EQ(std::make_pair(0, 27),
- model->SuggestSelection("350 Third Street, Cambridge", {0, 3}));
- EXPECT_EQ(std::make_pair(0, 27),
- model->SuggestSelection("350 Third Street, Cambridge", {4, 9}));
- EXPECT_EQ(std::make_pair(0, 27),
- model->SuggestSelection("350 Third Street, Cambridge", {10, 16}));
- EXPECT_EQ(std::make_pair(6, 33),
- model->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge",
- {16, 22}));
-}
-
-TEST(TextClassificationModelTest, SuggestSelectionWithNewLine) {
- const std::string model_path = GetModelPath();
- int fd = open(model_path.c_str(), O_RDONLY);
- std::unique_ptr<TextClassificationModel> model(
- new TextClassificationModel(fd));
- close(fd);
-
- std::tuple<int, int> selection;
- selection = model->SuggestSelection("abc\nBarack Obama", {4, 10});
- EXPECT_EQ(4, std::get<0>(selection));
- EXPECT_EQ(16, std::get<1>(selection));
-
- selection = model->SuggestSelection("Barack Obama\nabc", {0, 6});
- EXPECT_EQ(0, std::get<0>(selection));
- EXPECT_EQ(12, std::get<1>(selection));
-}
-
-TEST(TextClassificationModelTest, SuggestSelectionWithPunctuation) {
- const std::string model_path = GetModelPath();
- int fd = open(model_path.c_str(), O_RDONLY);
- std::unique_ptr<TextClassificationModel> model(
- new TextClassificationModel(fd));
- close(fd);
-
- std::tuple<int, int> selection;
-
- // From the right.
- selection = model->SuggestSelection(
- "this afternoon Barack Obama, gave a speech at", {15, 21});
- EXPECT_EQ(15, std::get<0>(selection));
- EXPECT_EQ(27, std::get<1>(selection));
-
- // From the right multiple.
- selection = model->SuggestSelection(
- "this afternoon Barack Obama,.,.,, gave a speech at", {15, 21});
- EXPECT_EQ(15, std::get<0>(selection));
- EXPECT_EQ(27, std::get<1>(selection));
-
- // From the left multiple.
- selection = model->SuggestSelection(
- "this afternoon ,.,.,,Barack Obama gave a speech at", {21, 27});
- EXPECT_EQ(21, std::get<0>(selection));
- EXPECT_EQ(27, std::get<1>(selection));
-
- // From both sides.
- selection = model->SuggestSelection(
- "this afternoon !Barack Obama,- gave a speech at", {16, 22});
- EXPECT_EQ(16, std::get<0>(selection));
- EXPECT_EQ(28, std::get<1>(selection));
-}
-
-class TestingTextClassificationModel
- : public libtextclassifier::TextClassificationModel {
- public:
- explicit TestingTextClassificationModel(int fd)
- : libtextclassifier::TextClassificationModel(fd) {}
-
- using TextClassificationModel::InitializeSharingRegexPatterns;
-
- void DisableClassificationHints() {
- sharing_options_.set_always_accept_url_hint(false);
- sharing_options_.set_always_accept_email_hint(false);
- }
-};
-
-TEST(TextClassificationModelTest, SuggestSelectionNoCrashWithJunk) {
- const std::string model_path = GetModelPath();
- int fd = open(model_path.c_str(), O_RDONLY);
- std::unique_ptr<TextClassificationModel> ff_model(
- new TextClassificationModel(fd));
- close(fd);
-
- std::tuple<int, int> selection;
-
- // Try passing in bunch of invalid selections.
- selection = ff_model->SuggestSelection("", {0, 27});
- // If more than 1 token is specified, we should return back what entered.
- EXPECT_EQ(0, std::get<0>(selection));
- EXPECT_EQ(27, std::get<1>(selection));
-
- selection = ff_model->SuggestSelection("", {-10, 27});
- // If more than 1 token is specified, we should return back what entered.
- EXPECT_EQ(-10, std::get<0>(selection));
- EXPECT_EQ(27, std::get<1>(selection));
-
- selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {0, 27});
- // If more than 1 token is specified, we should return back what entered.
- EXPECT_EQ(0, std::get<0>(selection));
- EXPECT_EQ(27, std::get<1>(selection));
-
- selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {-30, 300});
- // If more than 1 token is specified, we should return back what entered.
- EXPECT_EQ(-30, std::get<0>(selection));
- EXPECT_EQ(300, std::get<1>(selection));
-
- selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {-10, -1});
- // If more than 1 token is specified, we should return back what entered.
- EXPECT_EQ(-10, std::get<0>(selection));
- EXPECT_EQ(-1, std::get<1>(selection));
-
- selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {100, 17});
- // If more than 1 token is specified, we should return back what entered.
- EXPECT_EQ(100, std::get<0>(selection));
- EXPECT_EQ(17, std::get<1>(selection));
-}
-
-namespace {
-
-std::string FindBestResult(std::vector<std::pair<std::string, float>> results) {
- if (results.empty()) {
- return "<INVALID RESULTS>";
- }
-
- std::sort(results.begin(), results.end(),
- [](const std::pair<std::string, float> a,
- const std::pair<std::string, float> b) {
- return a.second > b.second;
- });
- return results[0].first;
-}
-
-} // namespace
-
-TEST(TextClassificationModelTest, ClassifyText) {
- const std::string model_path = GetModelPath();
- int fd = open(model_path.c_str(), O_RDONLY);
- std::unique_ptr<TestingTextClassificationModel> model(
- new TestingTextClassificationModel(fd));
- close(fd);
-
- model->DisableClassificationHints();
- EXPECT_EQ("other",
- FindBestResult(model->ClassifyText(
- "this afternoon Barack Obama gave a speech at", {15, 27})));
- EXPECT_EQ("other",
- FindBestResult(model->ClassifyText("you@android.com", {0, 15})));
- EXPECT_EQ("other", FindBestResult(model->ClassifyText(
- "Contact me at you@android.com", {14, 29})));
- EXPECT_EQ("phone", FindBestResult(model->ClassifyText(
- "Call me at (800) 123-456 today", {11, 24})));
- EXPECT_EQ("other", FindBestResult(model->ClassifyText(
- "Visit www.google.com every today!", {6, 20})));
-
- // More lines.
- EXPECT_EQ("other",
- FindBestResult(model->ClassifyText(
- "this afternoon Barack Obama gave a speech at|Visit "
- "www.google.com every today!|Call me at (800) 123-456 today.",
- {15, 27})));
- EXPECT_EQ("other",
- FindBestResult(model->ClassifyText(
- "this afternoon Barack Obama gave a speech at|Visit "
- "www.google.com every today!|Call me at (800) 123-456 today.",
- {51, 65})));
- EXPECT_EQ("phone",
- FindBestResult(model->ClassifyText(
- "this afternoon Barack Obama gave a speech at|Visit "
- "www.google.com every today!|Call me at (800) 123-456 today.",
- {90, 103})));
-
- // Single word.
- EXPECT_EQ("other", FindBestResult(model->ClassifyText("obama", {0, 5})));
- EXPECT_EQ("other", FindBestResult(model->ClassifyText("asdf", {0, 4})));
- EXPECT_EQ("<INVALID RESULTS>",
- FindBestResult(model->ClassifyText("asdf", {0, 0})));
-
- // Junk.
- EXPECT_EQ("<INVALID RESULTS>",
- FindBestResult(model->ClassifyText("", {0, 0})));
- EXPECT_EQ("<INVALID RESULTS>", FindBestResult(model->ClassifyText(
- "a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})));
-}
-
-TEST(TextClassificationModelTest, ClassifyTextWithHints) {
- const std::string model_path = GetModelPath();
- int fd = open(model_path.c_str(), O_RDONLY);
- std::unique_ptr<TestingTextClassificationModel> model(
- new TestingTextClassificationModel(fd));
- close(fd);
-
- // When EMAIL hint is passed, the result should be email.
- EXPECT_EQ("email",
- FindBestResult(model->ClassifyText(
- "x", {0, 1}, TextClassificationModel::SELECTION_IS_EMAIL)));
- // When URL hint is passed, the result should be email.
- EXPECT_EQ("url",
- FindBestResult(model->ClassifyText(
- "x", {0, 1}, TextClassificationModel::SELECTION_IS_URL)));
- // When both hints are passed, the result should be url (as it's probably
- // better to let Chrome handle this case).
- EXPECT_EQ("url", FindBestResult(model->ClassifyText(
- "x", {0, 1},
- TextClassificationModel::SELECTION_IS_EMAIL |
- TextClassificationModel::SELECTION_IS_URL)));
-
- // With disabled hints, we should get the same prediction regardless of the
- // hint.
- model->DisableClassificationHints();
- EXPECT_EQ(model->ClassifyText("x", {0, 1}, 0),
- model->ClassifyText("x", {0, 1},
- TextClassificationModel::SELECTION_IS_EMAIL));
-
- EXPECT_EQ(model->ClassifyText("x", {0, 1}, 0),
- model->ClassifyText("x", {0, 1},
- TextClassificationModel::SELECTION_IS_URL));
-}
-
-TEST(TextClassificationModelTest, PhoneFiltering) {
- const std::string model_path = GetModelPath();
- int fd = open(model_path.c_str(), O_RDONLY);
- std::unique_ptr<TestingTextClassificationModel> model(
- new TestingTextClassificationModel(fd));
- close(fd);
-
- EXPECT_EQ("phone", FindBestResult(model->ClassifyText("phone: (123) 456 789",
- {7, 20}, 0)));
- EXPECT_EQ("phone", FindBestResult(model->ClassifyText(
- "phone: (123) 456 789,0001112", {7, 25}, 0)));
- EXPECT_EQ("other", FindBestResult(model->ClassifyText(
- "phone: (123) 456 789,0001112", {7, 28}, 0)));
-}
-
-TEST(TextClassificationModelTest, Annotate) {
- const std::string model_path = GetModelPath();
- int fd = open(model_path.c_str(), O_RDONLY);
- std::unique_ptr<TestingTextClassificationModel> model(
- new TestingTextClassificationModel(fd));
- close(fd);
-
- std::string test_string =
- "& saw Barak Obama today .. 350 Third Street, Cambridge\nand my phone "
- "number is 853 225-3556.";
- std::vector<TextClassificationModel::AnnotatedSpan> result =
- model->Annotate(test_string);
-
- std::vector<TextClassificationModel::AnnotatedSpan> expected;
- expected.emplace_back();
- expected.back().span = {0, 0};
- expected.emplace_back();
- expected.back().span = {2, 5};
- expected.back().classification.push_back({"other", 1.0});
- expected.emplace_back();
- expected.back().span = {6, 17};
- expected.back().classification.push_back({"other", 1.0});
- expected.emplace_back();
- expected.back().span = {18, 23};
- expected.back().classification.push_back({"other", 1.0});
- expected.emplace_back();
- expected.back().span = {24, 24};
- expected.emplace_back();
- expected.back().span = {27, 54};
- expected.back().classification.push_back({"address", 1.0});
- expected.emplace_back();
- expected.back().span = {55, 58};
- expected.back().classification.push_back({"other", 1.0});
- expected.emplace_back();
- expected.back().span = {59, 61};
- expected.back().classification.push_back({"other", 1.0});
- expected.emplace_back();
- expected.back().span = {62, 74};
- expected.back().classification.push_back({"other", 1.0});
- expected.emplace_back();
- expected.back().span = {75, 77};
- expected.back().classification.push_back({"other", 1.0});
- expected.emplace_back();
- expected.back().span = {78, 90};
- expected.back().classification.push_back({"phone", 1.0});
-
- EXPECT_EQ(result.size(), expected.size());
- for (int i = 0; i < expected.size(); ++i) {
- EXPECT_EQ(result[i].span, expected[i].span) << result[i];
- if (!expected[i].classification.empty()) {
- EXPECT_GT(result[i].classification.size(), 0);
- EXPECT_EQ(result[i].classification[0].first,
- expected[i].classification[0].first)
- << result[i];
- }
- }
-}
-
-TEST(TextClassificationModelTest, URLEmailRegex) {
- const std::string model_path = GetModelPath();
- int fd = open(model_path.c_str(), O_RDONLY);
- std::unique_ptr<TestingTextClassificationModel> model(
- new TestingTextClassificationModel(fd));
- close(fd);
-
- SharingModelOptions options;
- SharingModelOptions::RegexPattern* email_pattern =
- options.add_regex_pattern();
- email_pattern->set_collection_name("email");
- email_pattern->set_pattern(ReadFile(GetEmailRegexPath()));
- SharingModelOptions::RegexPattern* url_pattern = options.add_regex_pattern();
- url_pattern->set_collection_name("url");
- url_pattern->set_pattern(ReadFile(GetURLRegexPath()));
-
- // TODO(b/69538802): Modify directly the model image instead.
- model->InitializeSharingRegexPatterns(
- {options.regex_pattern().begin(), options.regex_pattern().end()});
-
- EXPECT_EQ("url", FindBestResult(model->ClassifyText(
- "Visit www.google.com every today!", {6, 20})));
- EXPECT_EQ("email", FindBestResult(model->ClassifyText(
- "My email: asdf@something.cz", {10, 27})));
- EXPECT_EQ("url", FindBestResult(model->ClassifyText(
- "Login: http://asdf@something.cz", {7, 31})));
-}
-
-} // namespace
-} // namespace libtextclassifier
diff --git a/smartselect/tokenizer.cc b/smartselect/tokenizer.cc
deleted file mode 100644
index 2489a61..0000000
--- a/smartselect/tokenizer.cc
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "smartselect/tokenizer.h"
-
-#include <algorithm>
-
-#include "util/strings/utf8.h"
-#include "util/utf8/unicodetext.h"
-
-namespace libtextclassifier {
-
-Tokenizer::Tokenizer(
- const std::vector<TokenizationCodepointRange>& codepoint_ranges)
- : codepoint_ranges_(codepoint_ranges) {
- std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
- [](const TokenizationCodepointRange& a,
- const TokenizationCodepointRange& b) {
- return a.start() < b.start();
- });
-}
-
-TokenizationCodepointRange::Role Tokenizer::FindTokenizationRole(
- int codepoint) const {
- auto it = std::lower_bound(
- codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint,
- [](const TokenizationCodepointRange& range, int codepoint) {
- // This function compares range with the codepoint for the purpose of
- // finding the first greater or equal range. Because of the use of
- // std::lower_bound it needs to return true when range < codepoint;
- // the first time it will return false the lower bound is found and
- // returned.
- //
- // It might seem weird that the condition is range.end <= codepoint
- // here but when codepoint == range.end it means it's actually just
- // outside of the range, thus the range is less than the codepoint.
- return range.end() <= codepoint;
- });
- if (it != codepoint_ranges_.end() && it->start() <= codepoint &&
- it->end() > codepoint) {
- return it->role();
- } else {
- return TokenizationCodepointRange::DEFAULT_ROLE;
- }
-}
-
-std::vector<Token> Tokenizer::Tokenize(const std::string& utf8_text) const {
- UnicodeText context_unicode = UTF8ToUnicodeText(utf8_text, /*do_copy=*/false);
-
- std::vector<Token> result;
- Token new_token("", 0, 0);
- int codepoint_index = 0;
- for (auto it = context_unicode.begin(); it != context_unicode.end();
- ++it, ++codepoint_index) {
- TokenizationCodepointRange::Role role = FindTokenizationRole(*it);
- if (role & TokenizationCodepointRange::SPLIT_BEFORE) {
- if (!new_token.value.empty()) {
- result.push_back(new_token);
- }
- new_token = Token("", codepoint_index, codepoint_index);
- }
- if (!(role & TokenizationCodepointRange::DISCARD_CODEPOINT)) {
- new_token.value += std::string(
- it.utf8_data(),
- it.utf8_data() + GetNumBytesForNonZeroUTF8Char(it.utf8_data()));
- ++new_token.end;
- }
- if (role & TokenizationCodepointRange::SPLIT_AFTER) {
- if (!new_token.value.empty()) {
- result.push_back(new_token);
- }
- new_token = Token("", codepoint_index + 1, codepoint_index + 1);
- }
- }
- if (!new_token.value.empty()) {
- result.push_back(new_token);
- }
-
- return result;
-}
-
-} // namespace libtextclassifier
diff --git a/smartselect/tokenizer.h b/smartselect/tokenizer.h
deleted file mode 100644
index 4eb78f9..0000000
--- a/smartselect/tokenizer.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TOKENIZER_H_
-#define LIBTEXTCLASSIFIER_SMARTSELECT_TOKENIZER_H_
-
-#include <string>
-#include <vector>
-
-#include "smartselect/tokenizer.pb.h"
-#include "smartselect/types.h"
-
-namespace libtextclassifier {
-
-// Tokenizer splits the input string into a sequence of tokens, according to the
-// configuration.
-class Tokenizer {
- public:
- explicit Tokenizer(
- const std::vector<TokenizationCodepointRange>& codepoint_ranges);
-
- // Tokenizes the input string using the selected tokenization method.
- std::vector<Token> Tokenize(const std::string& utf8_text) const;
-
- protected:
- // Finds the tokenization role for given codepoint.
- // If the character is not found returns DEFAULT_ROLE.
- // Internally uses binary search so should be O(log(# of codepoint_ranges)).
- TokenizationCodepointRange::Role FindTokenizationRole(int codepoint) const;
-
- private:
- // Codepoint ranges that determine how different codepoints are tokenized.
- // The ranges must not overlap.
- std::vector<TokenizationCodepointRange> codepoint_ranges_;
-};
-
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_SMARTSELECT_TOKENIZER_H_
diff --git a/smartselect/tokenizer.proto b/smartselect/tokenizer.proto
deleted file mode 100644
index 8e78970..0000000
--- a/smartselect/tokenizer.proto
+++ /dev/null
@@ -1,48 +0,0 @@
-// Copyright (C) 2017 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto2";
-option optimize_for = LITE_RUNTIME;
-
-package libtextclassifier;
-
-// Represents a codepoint range [start, end) with its role for tokenization.
-message TokenizationCodepointRange {
- optional int32 start = 1;
- optional int32 end = 2;
-
- // Role of the codepoints in the range.
- enum Role {
- // Concatenates the codepoint to the current run of codepoints.
- DEFAULT_ROLE = 0;
-
- // Splits a run of codepoints before the current codepoint.
- SPLIT_BEFORE = 0x1;
-
- // Splits a run of codepoints after the current codepoint.
- SPLIT_AFTER = 0x2;
-
- // Discards the codepoint.
- DISCARD_CODEPOINT = 0x4;
-
- // Common values:
- // Splits on the characters and discards them. Good e.g. for the space
- // character.
- WHITESPACE_SEPARATOR = 0x7;
- // Each codepoint will be a separate token. Good e.g. for Chinese
- // characters.
- TOKEN_SEPARATOR = 0x3;
- }
- optional Role role = 3;
-}
diff --git a/smartselect/tokenizer_test.cc b/smartselect/tokenizer_test.cc
deleted file mode 100644
index cdb90a9..0000000
--- a/smartselect/tokenizer_test.cc
+++ /dev/null
@@ -1,261 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "smartselect/tokenizer.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier {
-namespace {
-
-using testing::ElementsAreArray;
-
-class TestingTokenizer : public Tokenizer {
- public:
- explicit TestingTokenizer(
- const std::vector<TokenizationCodepointRange>& codepoint_range_configs)
- : Tokenizer(codepoint_range_configs) {}
-
- TokenizationCodepointRange::Role TestFindTokenizationRole(int c) const {
- return FindTokenizationRole(c);
- }
-};
-
-TEST(TokenizerTest, FindTokenizationRole) {
- std::vector<TokenizationCodepointRange> configs;
- TokenizationCodepointRange* config;
-
- configs.emplace_back();
- config = &configs.back();
- config->set_start(0);
- config->set_end(10);
- config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
-
- configs.emplace_back();
- config = &configs.back();
- config->set_start(32);
- config->set_end(33);
- config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
-
- configs.emplace_back();
- config = &configs.back();
- config->set_start(1234);
- config->set_end(12345);
- config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
-
- TestingTokenizer tokenizer(configs);
-
- // Test hits to the first group.
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(0),
- TokenizationCodepointRange::TOKEN_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(5),
- TokenizationCodepointRange::TOKEN_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(10),
- TokenizationCodepointRange::DEFAULT_ROLE);
-
- // Test a hit to the second group.
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(31),
- TokenizationCodepointRange::DEFAULT_ROLE);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(32),
- TokenizationCodepointRange::WHITESPACE_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(33),
- TokenizationCodepointRange::DEFAULT_ROLE);
-
- // Test hits to the third group.
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(1233),
- TokenizationCodepointRange::DEFAULT_ROLE);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(1234),
- TokenizationCodepointRange::TOKEN_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(12344),
- TokenizationCodepointRange::TOKEN_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(12345),
- TokenizationCodepointRange::DEFAULT_ROLE);
-
- // Test a hit outside.
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(99),
- TokenizationCodepointRange::DEFAULT_ROLE);
-}
-
-TEST(TokenizerTest, TokenizeOnSpace) {
- std::vector<TokenizationCodepointRange> configs;
- TokenizationCodepointRange* config;
-
- configs.emplace_back();
- config = &configs.back();
- // Space character.
- config->set_start(32);
- config->set_end(33);
- config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
-
- TestingTokenizer tokenizer(configs);
- std::vector<Token> tokens = tokenizer.Tokenize("Hello world!");
-
- EXPECT_THAT(tokens,
- ElementsAreArray({Token("Hello", 0, 5), Token("world!", 6, 12)}));
-}
-
-TEST(TokenizerTest, TokenizeComplex) {
- std::vector<TokenizationCodepointRange> configs;
- TokenizationCodepointRange* config;
-
- // Source: http://www.unicode.org/Public/10.0.0/ucd/Blocks-10.0.0d1.txt
- // Latin - cyrilic.
- // 0000..007F; Basic Latin
- // 0080..00FF; Latin-1 Supplement
- // 0100..017F; Latin Extended-A
- // 0180..024F; Latin Extended-B
- // 0250..02AF; IPA Extensions
- // 02B0..02FF; Spacing Modifier Letters
- // 0300..036F; Combining Diacritical Marks
- // 0370..03FF; Greek and Coptic
- // 0400..04FF; Cyrillic
- // 0500..052F; Cyrillic Supplement
- // 0530..058F; Armenian
- // 0590..05FF; Hebrew
- // 0600..06FF; Arabic
- // 0700..074F; Syriac
- // 0750..077F; Arabic Supplement
- configs.emplace_back();
- config = &configs.back();
- config->set_start(0);
- config->set_end(32);
- config->set_role(TokenizationCodepointRange::DEFAULT_ROLE);
- configs.emplace_back();
- config = &configs.back();
- config->set_start(32);
- config->set_end(33);
- config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
- configs.emplace_back();
- config = &configs.back();
- config->set_start(33);
- config->set_end(0x77F + 1);
- config->set_role(TokenizationCodepointRange::DEFAULT_ROLE);
-
- // CJK
- // 2E80..2EFF; CJK Radicals Supplement
- // 3000..303F; CJK Symbols and Punctuation
- // 3040..309F; Hiragana
- // 30A0..30FF; Katakana
- // 3100..312F; Bopomofo
- // 3130..318F; Hangul Compatibility Jamo
- // 3190..319F; Kanbun
- // 31A0..31BF; Bopomofo Extended
- // 31C0..31EF; CJK Strokes
- // 31F0..31FF; Katakana Phonetic Extensions
- // 3200..32FF; Enclosed CJK Letters and Months
- // 3300..33FF; CJK Compatibility
- // 3400..4DBF; CJK Unified Ideographs Extension A
- // 4DC0..4DFF; Yijing Hexagram Symbols
- // 4E00..9FFF; CJK Unified Ideographs
- // A000..A48F; Yi Syllables
- // A490..A4CF; Yi Radicals
- // A4D0..A4FF; Lisu
- // A500..A63F; Vai
- // F900..FAFF; CJK Compatibility Ideographs
- // FE30..FE4F; CJK Compatibility Forms
- // 20000..2A6DF; CJK Unified Ideographs Extension B
- // 2A700..2B73F; CJK Unified Ideographs Extension C
- // 2B740..2B81F; CJK Unified Ideographs Extension D
- // 2B820..2CEAF; CJK Unified Ideographs Extension E
- // 2CEB0..2EBEF; CJK Unified Ideographs Extension F
- // 2F800..2FA1F; CJK Compatibility Ideographs Supplement
- configs.emplace_back();
- config = &configs.back();
- config->set_start(0x2E80);
- config->set_end(0x2EFF + 1);
- config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
- configs.emplace_back();
- config = &configs.back();
- config->set_start(0x3000);
- config->set_end(0xA63F + 1);
- config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
- configs.emplace_back();
- config = &configs.back();
- config->set_start(0xF900);
- config->set_end(0xFAFF + 1);
- config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
- configs.emplace_back();
- config = &configs.back();
- config->set_start(0xFE30);
- config->set_end(0xFE4F + 1);
- config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
- configs.emplace_back();
- config = &configs.back();
- config->set_start(0x20000);
- config->set_end(0x2A6DF + 1);
- config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
- configs.emplace_back();
- config = &configs.back();
- config->set_start(0x2A700);
- config->set_end(0x2B73F + 1);
- config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
- configs.emplace_back();
- config = &configs.back();
- config->set_start(0x2B740);
- config->set_end(0x2B81F + 1);
- config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
- configs.emplace_back();
- config = &configs.back();
- config->set_start(0x2B820);
- config->set_end(0x2CEAF + 1);
- config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
- configs.emplace_back();
- config = &configs.back();
- config->set_start(0x2CEB0);
- config->set_end(0x2EBEF + 1);
- config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
- configs.emplace_back();
- config = &configs.back();
- config->set_start(0x2F800);
- config->set_end(0x2FA1F + 1);
- config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
-
- // Thai.
- // 0E00..0E7F; Thai
- configs.emplace_back();
- config = &configs.back();
- config->set_start(0x0E00);
- config->set_end(0x0E7F + 1);
- config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
-
- Tokenizer tokenizer(configs);
- std::vector<Token> tokens;
-
- tokens = tokenizer.Tokenize(
- "問少目木輸走猶術権自京門録球変。細開括省用掲情結傍走愛明氷。");
- EXPECT_EQ(tokens.size(), 30);
-
- tokens = tokenizer.Tokenize("問少目 hello 木輸ยามきゃ");
- // clang-format off
- EXPECT_THAT(
- tokens,
- ElementsAreArray({Token("問", 0, 1),
- Token("少", 1, 2),
- Token("目", 2, 3),
- Token("hello", 4, 9),
- Token("木", 10, 11),
- Token("輸", 11, 12),
- Token("ย", 12, 13),
- Token("า", 13, 14),
- Token("ม", 14, 15),
- Token("き", 15, 16),
- Token("ゃ", 16, 17)}));
- // clang-format on
-}
-
-} // namespace
-} // namespace libtextclassifier
diff --git a/smartselect/types.h b/smartselect/types.h
deleted file mode 100644
index 443e3ac..0000000
--- a/smartselect/types.h
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TYPES_H_
-#define LIBTEXTCLASSIFIER_SMARTSELECT_TYPES_H_
-
-#include <ostream>
-#include <string>
-#include <utility>
-
-namespace libtextclassifier {
-
-constexpr int kInvalidIndex = -1;
-
-// Index for a 0-based array of tokens.
-using TokenIndex = int;
-
-// Index for a 0-based array of codepoints.
-using CodepointIndex = int;
-
-// Marks a span in a sequence of codepoints. The first element is the index of
-// the first codepoint of the span, and the second element is the index of the
-// codepoint one past the end of the span.
-using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;
-
-// Marks a span in a sequence of tokens. The first element is the index of the
-// first token in the span, and the second element is the index of the token one
-// past the end of the span.
-using TokenSpan = std::pair<TokenIndex, TokenIndex>;
-
-// Token holds a token, its position in the original string and whether it was
-// part of the input span.
-struct Token {
- std::string value;
- CodepointIndex start;
- CodepointIndex end;
-
- // Whether the token is a padding token.
- bool is_padding;
-
- // Default constructor constructs the padding-token.
- Token()
- : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {}
-
- Token(const std::string& arg_value, CodepointIndex arg_start,
- CodepointIndex arg_end)
- : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {}
-
- bool operator==(const Token& other) const {
- return value == other.value && start == other.start && end == other.end &&
- is_padding == other.is_padding;
- }
-
- bool IsContainedInSpan(CodepointSpan span) const {
- return start >= span.first && end <= span.second;
- }
-};
-
-// Pretty-printing function for Token.
-inline std::ostream& operator<<(std::ostream& os, const Token& token) {
- return os << "Token(\"" << token.value << "\", " << token.start << ", "
- << token.end << ", is_padding=" << token.is_padding << ")";
-}
-
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_SMARTSELECT_TYPES_H_
diff --git a/strip-unpaired-brackets.cc b/strip-unpaired-brackets.cc
new file mode 100644
index 0000000..f813e6b
--- /dev/null
+++ b/strip-unpaired-brackets.cc
@@ -0,0 +1,105 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "strip-unpaired-brackets.h"
+
+#include <iterator>
+
+#include "util/base/logging.h"
+#include "util/utf8/unicodetext.h"
+
+namespace libtextclassifier2 {
+namespace {
+
+// Returns true if given codepoint is contained in the given span in context.
+bool IsCodepointInSpan(const char32 codepoint, const std::string& context,
+ const CodepointSpan span) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+
+ auto begin_it = context_unicode.begin();
+ std::advance(begin_it, span.first);
+ auto end_it = context_unicode.begin();
+ std::advance(end_it, span.second);
+
+ return std::find(begin_it, end_it, codepoint) != end_it;
+}
+
+// Returns the first codepoint of the span.
+char32 FirstSpanCodepoint(const std::string& context,
+ const CodepointSpan span) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+
+ auto it = context_unicode.begin();
+ std::advance(it, span.first);
+ return *it;
+}
+
+// Returns the last codepoint of the span.
+char32 LastSpanCodepoint(const std::string& context, const CodepointSpan span) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+
+ auto it = context_unicode.begin();
+ std::advance(it, span.second - 1);
+ return *it;
+}
+
+} // namespace
+
+// If the first or the last codepoint of the given span is a bracket, the
+// bracket is stripped if the span does not contain its corresponding paired
+// version.
+CodepointSpan StripUnpairedBrackets(const std::string& context,
+ CodepointSpan span, const UniLib& unilib) {
+ if (context.empty()) {
+ return span;
+ }
+
+ const char32 begin_char = FirstSpanCodepoint(context, span);
+ const char32 paired_begin_char = unilib.GetPairedBracket(begin_char);
+ if (paired_begin_char != begin_char) {
+ if (!unilib.IsOpeningBracket(begin_char) ||
+ !IsCodepointInSpan(paired_begin_char, context, span)) {
+ ++span.first;
+ }
+ }
+
+ if (span.first == span.second) {
+ return span;
+ }
+
+ const char32 end_char = LastSpanCodepoint(context, span);
+ const char32 paired_end_char = unilib.GetPairedBracket(end_char);
+ if (paired_end_char != end_char) {
+ if (!unilib.IsClosingBracket(end_char) ||
+ !IsCodepointInSpan(paired_end_char, context, span)) {
+ --span.second;
+ }
+ }
+
+ // Should not happen, but let's make sure.
+ if (span.first > span.second) {
+ TC_LOG(WARNING) << "Inverse indices result: " << span.first << ", "
+ << span.second;
+ span.second = span.first;
+ }
+
+ return span;
+}
+
+} // namespace libtextclassifier2
diff --git a/strip-unpaired-brackets.h b/strip-unpaired-brackets.h
new file mode 100644
index 0000000..2d7893e
--- /dev/null
+++ b/strip-unpaired-brackets.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_STRIP_UNPAIRED_BRACKETS_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_STRIP_UNPAIRED_BRACKETS_H_
+
+#include <string>
+
+#include "types.h"
+#include "util/utf8/unilib.h"
+
+namespace libtextclassifier2 {
+// If the first or the last codepoint of the given span is a bracket, the
+// bracket is stripped if the span does not contain its corresponding paired
+// version.
+CodepointSpan StripUnpairedBrackets(const std::string& context,
+ CodepointSpan span, const UniLib& unilib);
+} // namespace libtextclassifier2
+
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_STRIP_UNPAIRED_BRACKETS_H_
diff --git a/strip-unpaired-brackets_test.cc b/strip-unpaired-brackets_test.cc
new file mode 100644
index 0000000..fb99d82
--- /dev/null
+++ b/strip-unpaired-brackets_test.cc
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "strip-unpaired-brackets.h"
+
+#include "gtest/gtest.h"
+
+namespace libtextclassifier2 {
+namespace {
+
+TEST(StripUnpairedBracketsTest, StripUnpairedBrackets) {
+ UniLib unilib;
+ // If the brackets match, nothing gets stripped.
+ EXPECT_EQ(StripUnpairedBrackets("call me (123) 456 today", {8, 17}, unilib),
+ std::make_pair(8, 17));
+ EXPECT_EQ(StripUnpairedBrackets("call me (123 456) today", {8, 17}, unilib),
+ std::make_pair(8, 17));
+
+ // If the brackets don't match, they get stripped.
+ EXPECT_EQ(StripUnpairedBrackets("call me (123 456 today", {8, 16}, unilib),
+ std::make_pair(9, 16));
+ EXPECT_EQ(StripUnpairedBrackets("call me )123 456 today", {8, 16}, unilib),
+ std::make_pair(9, 16));
+ EXPECT_EQ(StripUnpairedBrackets("call me 123 456) today", {8, 16}, unilib),
+ std::make_pair(8, 15));
+ EXPECT_EQ(StripUnpairedBrackets("call me 123 456( today", {8, 16}, unilib),
+ std::make_pair(8, 15));
+
+ // Strips brackets correctly from length-1 selections that consist of
+ // a bracket only.
+ EXPECT_EQ(StripUnpairedBrackets("call me at ) today", {11, 12}, unilib),
+ std::make_pair(12, 12));
+ EXPECT_EQ(StripUnpairedBrackets("call me at ( today", {11, 12}, unilib),
+ std::make_pair(12, 12));
+}
+
+} // namespace
+} // namespace libtextclassifier2
diff --git a/common/config.h b/tensor-view.cc
similarity index 67%
rename from common/config.h
rename to tensor-view.cc
index b883e95..4acadc5 100644
--- a/common/config.h
+++ b/tensor-view.cc
@@ -14,16 +14,18 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_COMMON_CONFIG_H_
-#define LIBTEXTCLASSIFIER_COMMON_CONFIG_H_
+#include "tensor-view.h"
-#ifndef PORTABLE_SAFT_MOBILE
-#if defined(__ANDROID__) || defined(__APPLE__)
-#define PORTABLE_SAFT_MOBILE 1
-#else
-#define PORTABLE_SAFT_MOBILE 0
-#endif
+namespace libtextclassifier2 {
-#endif
+namespace internal {
+int NumberOfElements(const std::vector<int>& shape) {
+ int size = 1;
+ for (const int dim : shape) {
+ size *= dim;
+ }
+ return size;
+}
+} // namespace internal
-#endif // LIBTEXTCLASSIFIER_COMMON_CONFIG_H_
+} // namespace libtextclassifier2
diff --git a/tensor-view.h b/tensor-view.h
new file mode 100644
index 0000000..69788c8
--- /dev/null
+++ b/tensor-view.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TENSOR_VIEW_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TENSOR_VIEW_H_
+
+#include <algorithm>
+#include <vector>
+
+namespace libtextclassifier2 {
+namespace internal {
+// Computes the number of elements in a tensor of given shape.
+int NumberOfElements(const std::vector<int>& shape);
+} // namespace internal
+
+// View of a tensor of given type.
+// NOTE: Does not own the underlying memory, so the contract about its validity
+// needs to be specified on the interface that returns it.
+template <typename T>
+class TensorView {
+ public:
+ TensorView(const T* data, const std::vector<int>& shape)
+ : data_(data), shape_(shape), size_(internal::NumberOfElements(shape)) {}
+
+ static TensorView Invalid() {
+ static std::vector<int>& invalid_shape =
+ *[]() { return new std::vector<int>(0); }();
+ return TensorView(nullptr, invalid_shape);
+ }
+
+ bool is_valid() const { return data_ != nullptr; }
+
+ const std::vector<int>& shape() const { return shape_; }
+
+ int dim(int i) const { return shape_[i]; }
+
+ int dims() const { return shape_.size(); }
+
+ const T* data() const { return data_; }
+
+ int size() const { return size_; }
+
+ bool copy_to(T* dest, int dest_size) const {
+ if (dest_size < size_) {
+ return false;
+ }
+ std::copy(data_, data_ + size_, dest);
+ return true;
+ }
+
+ private:
+ const T* data_ = nullptr;
+ const std::vector<int> shape_;
+ const int size_;
+};
+
+} // namespace libtextclassifier2
+
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TENSOR_VIEW_H_
diff --git a/tensor-view_test.cc b/tensor-view_test.cc
new file mode 100644
index 0000000..d50fac7
--- /dev/null
+++ b/tensor-view_test.cc
@@ -0,0 +1,52 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensor-view.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier2 {
+namespace {
+
+TEST(TensorViewTest, TestSize) {
+ std::vector<float> data{0.1, 0.2, 0.3, 0.4, 0.5, 0.6};
+ const TensorView<float> tensor(data.data(), {3, 1, 2});
+ EXPECT_TRUE(tensor.is_valid());
+ EXPECT_EQ(tensor.shape(), (std::vector<int>{3, 1, 2}));
+ EXPECT_EQ(tensor.data(), data.data());
+ EXPECT_EQ(tensor.size(), 6);
+ EXPECT_EQ(tensor.dims(), 3);
+ EXPECT_EQ(tensor.dim(0), 3);
+ EXPECT_EQ(tensor.dim(1), 1);
+ EXPECT_EQ(tensor.dim(2), 2);
+ std::vector<float> output_data(6);
+ EXPECT_TRUE(tensor.copy_to(output_data.data(), output_data.size()));
+ EXPECT_EQ(data, output_data);
+
+ // Should not copy when the output is small.
+ std::vector<float> small_output_data{-1, -1, -1};
+ EXPECT_FALSE(
+ tensor.copy_to(small_output_data.data(), small_output_data.size()));
+ // The output buffer should not be changed.
+ EXPECT_EQ(small_output_data, (std::vector<float>{-1, -1, -1}));
+
+ const TensorView<float> invalid_tensor = TensorView<float>::Invalid();
+ EXPECT_FALSE(invalid_tensor.is_valid());
+}
+
+} // namespace
+} // namespace libtextclassifier2
diff --git a/test_data/dummy.fb b/test_data/dummy.fb
new file mode 100644
index 0000000..4fec970
--- /dev/null
+++ b/test_data/dummy.fb
Binary files differ
diff --git a/test_data/test_model.fb b/test_data/test_model.fb
new file mode 100644
index 0000000..f62d9fe
--- /dev/null
+++ b/test_data/test_model.fb
Binary files differ
diff --git a/test_data/wrong_embeddings.fb b/test_data/wrong_embeddings.fb
new file mode 100644
index 0000000..513fcf5
--- /dev/null
+++ b/test_data/wrong_embeddings.fb
Binary files differ
diff --git a/tests/testdata/langid.model b/tests/testdata/langid.model
deleted file mode 100644
index 6b68223..0000000
--- a/tests/testdata/langid.model
+++ /dev/null
Binary files differ
diff --git a/tests/testdata/smartselection.model b/tests/testdata/smartselection.model
deleted file mode 100644
index 645303d..0000000
--- a/tests/testdata/smartselection.model
+++ /dev/null
Binary files differ
diff --git a/text-classifier.cc b/text-classifier.cc
new file mode 100644
index 0000000..1ee7e56
--- /dev/null
+++ b/text-classifier.cc
@@ -0,0 +1,592 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "text-classifier.h"
+
+#include <algorithm>
+#include <cctype>
+#include <cmath>
+#include <iterator>
+#include <numeric>
+
+#include "util/base/logging.h"
+#include "util/math/softmax.h"
+#include "util/utf8/unicodetext.h"
+
+namespace libtextclassifier2 {
+namespace {
+const Model* LoadAndVerifyModel(const void* addr, int size) {
+ const Model* model = flatbuffers::GetRoot<Model>(addr);
+
+ flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
+ if (model->Verify(verifier)) {
+ return model;
+ } else {
+ return nullptr;
+ }
+}
+} // namespace
+
+std::unique_ptr<TextClassifier> TextClassifier::FromUnownedBuffer(
+ const char* buffer, int size) {
+ const Model* model = LoadAndVerifyModel(buffer, size);
+ if (model == nullptr) {
+ return nullptr;
+ }
+
+ auto classifier = std::unique_ptr<TextClassifier>(new TextClassifier(model));
+ if (!classifier->IsInitialized()) {
+ return nullptr;
+ }
+
+ return classifier;
+}
+
+std::unique_ptr<TextClassifier> TextClassifier::FromScopedMmap(
+ std::unique_ptr<ScopedMmap>* mmap) {
+ if (!(*mmap)->handle().ok()) {
+ TC_VLOG(1) << "Mmap failed.";
+ return nullptr;
+ }
+
+ const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
+ (*mmap)->handle().num_bytes());
+ if (!model) {
+ TC_LOG(ERROR) << "Model verification failed.";
+ return nullptr;
+ }
+
+ auto classifier =
+ std::unique_ptr<TextClassifier>(new TextClassifier(mmap, model));
+ if (!classifier->IsInitialized()) {
+ return nullptr;
+ }
+
+ return classifier;
+}
+
+std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor(int fd,
+ int offset,
+ int size) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
+ return FromScopedMmap(&mmap);
+}
+
+std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor(int fd) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
+ return FromScopedMmap(&mmap);
+}
+
+std::unique_ptr<TextClassifier> TextClassifier::FromPath(
+ const std::string& path) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
+ return FromScopedMmap(&mmap);
+}
+
+void TextClassifier::ValidateAndInitialize() {
+ if (model_ == nullptr) {
+ TC_LOG(ERROR) << "No model specified.";
+ initialized_ = false;
+ return;
+ }
+
+ if (!model_->selection_options()) {
+ TC_LOG(ERROR) << "No selection options.";
+ initialized_ = false;
+ return;
+ }
+
+ if (!model_->classification_options()) {
+ TC_LOG(ERROR) << "No classification options.";
+ initialized_ = false;
+ return;
+ }
+
+ if (!model_->selection_feature_options()) {
+ TC_LOG(ERROR) << "No selection feature options.";
+ initialized_ = false;
+ return;
+ }
+
+ if (!model_->classification_feature_options()) {
+ TC_LOG(ERROR) << "No classification feature options.";
+ initialized_ = false;
+ return;
+ }
+
+ if (!model_->classification_feature_options()->bounds_sensitive_features()) {
+ TC_LOG(ERROR) << "No classification bounds sensitive feature options.";
+ initialized_ = false;
+ return;
+ }
+
+ if (!model_->selection_feature_options()->bounds_sensitive_features()) {
+ TC_LOG(ERROR) << "No selection bounds sensitive feature options.";
+ initialized_ = false;
+ return;
+ }
+
+ if (!model_->selection_model()) {
+ TC_LOG(ERROR) << "No selection model.";
+ initialized_ = false;
+ return;
+ }
+
+ if (!model_->embedding_model()) {
+ TC_LOG(ERROR) << "No embedding model.";
+ initialized_ = false;
+ return;
+ }
+
+ if (!model_->classification_model()) {
+ TC_LOG(ERROR) << "No clf model.";
+ initialized_ = false;
+ return;
+ }
+
+ if (model_->regex_options()) {
+ InitializeRegexModel();
+ }
+
+ embedding_executor_.reset(new TFLiteEmbeddingExecutor(
+ flatbuffers::GetRoot<tflite::Model>(model_->embedding_model()->data())));
+ if (!embedding_executor_ || !embedding_executor_->IsReady()) {
+ TC_LOG(ERROR) << "Could not initialize embedding executor.";
+ initialized_ = false;
+ return;
+ }
+ selection_executor_.reset(new ModelExecutor(
+ flatbuffers::GetRoot<tflite::Model>(model_->selection_model()->data())));
+ if (!selection_executor_) {
+ TC_LOG(ERROR) << "Could not initialize selection executor.";
+ initialized_ = false;
+ return;
+ }
+ classification_executor_.reset(
+ new ModelExecutor(flatbuffers::GetRoot<tflite::Model>(
+ model_->classification_model()->data())));
+ if (!classification_executor_) {
+ TC_LOG(ERROR) << "Could not initialize classification executor.";
+ initialized_ = false;
+ return;
+ }
+
+ selection_feature_processor_.reset(
+ new FeatureProcessor(model_->selection_feature_options(), unilib_.get()));
+ classification_feature_processor_.reset(new FeatureProcessor(
+ model_->classification_feature_options(), unilib_.get()));
+
+ initialized_ = true;
+}
+
+void TextClassifier::InitializeRegexModel() {
+ if (!model_->regex_options()->patterns()) {
+ initialized_ = false;
+ TC_LOG(ERROR) << "No patterns in the regex config.";
+ return;
+ }
+
+ // Initialize pattern recognizers.
+ for (const auto& regex_pattern : *model_->regex_options()->patterns()) {
+ std::unique_ptr<UniLib::RegexPattern> compiled_pattern(
+ unilib_->CreateRegexPattern(regex_pattern->pattern()->c_str()));
+
+ if (!compiled_pattern) {
+ TC_LOG(WARNING) << "Failed to load pattern"
+ << regex_pattern->pattern()->str();
+ continue;
+ }
+
+ regex_patterns_.push_back(
+ {regex_pattern->collection_name()->str(), std::move(compiled_pattern)});
+ }
+}
+
+namespace {
+
+int CountDigits(const std::string& str, CodepointSpan selection_indices) {
+ int count = 0;
+ int i = 0;
+ const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
+ for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
+ if (i >= selection_indices.first && i < selection_indices.second &&
+ isdigit(*it)) {
+ ++count;
+ }
+ }
+ return count;
+}
+
+std::string ExtractSelection(const std::string& context,
+ CodepointSpan selection_indices) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ auto selection_begin = context_unicode.begin();
+ std::advance(selection_begin, selection_indices.first);
+ auto selection_end = context_unicode.begin();
+ std::advance(selection_end, selection_indices.second);
+ return UnicodeText::UTF8Substring(selection_begin, selection_end);
+}
+
+} // namespace
+
+CodepointSpan TextClassifier::SuggestSelection(
+ const std::string& context, CodepointSpan click_indices) const {
+ if (!initialized_) {
+ TC_LOG(ERROR) << "Not initialized";
+ return click_indices;
+ }
+
+ const int context_codepoint_size =
+ UTF8ToUnicodeText(context, /*do_copy=*/false).size();
+
+ if (click_indices.first < 0 || click_indices.second < 0 ||
+ click_indices.first >= context_codepoint_size ||
+ click_indices.second > context_codepoint_size ||
+ click_indices.first >= click_indices.second) {
+ TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices: "
+ << click_indices.first << " " << click_indices.second;
+ return click_indices;
+ }
+
+ std::vector<Token> tokens;
+ int click_pos;
+ selection_feature_processor_->TokenizeAndFindClick(context, click_indices,
+ &tokens, &click_pos);
+ if (click_pos == kInvalidIndex) {
+ TC_VLOG(1) << "Could not calculate the click position.";
+ return click_indices;
+ }
+
+ const int symmetry_context_size =
+ model_->selection_options()->symmetry_context_size();
+ const int max_selection_span =
+ selection_feature_processor_->GetOptions()->max_selection_span();
+ const FeatureProcessorOptions_::BoundsSensitiveFeatures*
+ bounds_sensitive_features =
+ model_->selection_feature_options()->bounds_sensitive_features();
+
+ // The symmetry context span is the clicked token with symmetry_context_size
+ // tokens on either side.
+ const TokenSpan symmetry_context_span = IntersectTokenSpans(
+ ExpandTokenSpan(SingleTokenSpan(click_pos),
+ /*num_tokens_left=*/symmetry_context_size,
+ /*num_tokens_right=*/symmetry_context_size),
+ {0, tokens.size()});
+
+ // The extraction span is the symmetry context span expanded to include
+ // max_selection_span tokens on either side, which is how far a selection can
+ // stretch from the click, plus a relevant number of tokens outside of the
+ // bounds of the selection.
+ const TokenSpan extraction_span = IntersectTokenSpans(
+ ExpandTokenSpan(symmetry_context_span,
+ /*num_tokens_left=*/max_selection_span +
+ bounds_sensitive_features->num_tokens_before(),
+ /*num_tokens_right=*/max_selection_span +
+ bounds_sensitive_features->num_tokens_after()),
+ {0, tokens.size()});
+
+ std::unique_ptr<CachedFeatures> cached_features;
+ if (!classification_feature_processor_->ExtractFeatures(
+ tokens, extraction_span, embedding_executor_.get(),
+ classification_feature_processor_->EmbeddingSize() +
+ classification_feature_processor_->DenseFeaturesCount(),
+ &cached_features)) {
+ TC_LOG(ERROR) << "Could not extract features.";
+ return click_indices;
+ }
+
+ std::vector<TokenSpan> chunks;
+ if (!Chunk(tokens.size(), /*span_of_interest=*/symmetry_context_span,
+ *cached_features, &chunks)) {
+ TC_LOG(ERROR) << "Could not chunk.";
+ return click_indices;
+ }
+
+ CodepointSpan result = click_indices;
+ for (const TokenSpan& chunk : chunks) {
+ if (chunk.first <= click_pos && click_pos < chunk.second) {
+ result = selection_feature_processor_->StripBoundaryCodepoints(
+ context, TokenSpanToCodepointSpan(tokens, chunk));
+ break;
+ }
+ }
+
+ if (model_->selection_options()->strip_unpaired_brackets()) {
+ const CodepointSpan stripped_result =
+ StripUnpairedBrackets(context, result, *unilib_);
+ if (stripped_result.first != stripped_result.second) {
+ result = stripped_result;
+ }
+ }
+
+ return result;
+}
+
+std::vector<std::pair<std::string, float>> TextClassifier::ClassifyText(
+ const std::string& context, CodepointSpan selection_indices) const {
+ if (!initialized_) {
+ TC_LOG(ERROR) << "Not initialized";
+ return {};
+ }
+
+ if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
+ TC_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
+ << std::get<0>(selection_indices) << " "
+ << std::get<1>(selection_indices);
+ return {};
+ }
+
+ // Check whether any of the regular expressions match.
+ const std::string selection_text =
+ ExtractSelection(context, selection_indices);
+ for (const CompiledRegexPattern& regex_pattern : regex_patterns_) {
+ if (regex_pattern.pattern->Matches(selection_text)) {
+ return {{regex_pattern.collection_name, 1.0}};
+ }
+ }
+
+ const FeatureProcessorOptions_::BoundsSensitiveFeatures*
+ bounds_sensitive_features =
+ model_->classification_feature_options()->bounds_sensitive_features();
+
+ std::vector<Token> tokens;
+ classification_feature_processor_->TokenizeAndFindClick(
+ context, selection_indices, &tokens, /*click_pos=*/nullptr);
+ const TokenSpan selection_token_span =
+ CodepointSpanToTokenSpan(tokens, selection_indices);
+
+ if (selection_token_span.first == kInvalidIndex ||
+ selection_token_span.second == kInvalidIndex) {
+ return {};
+ }
+
+ // The extraction span is the selection span expanded to include a relevant
+ // number of tokens outside of the bounds of the selection.
+ const TokenSpan extraction_span = IntersectTokenSpans(
+ ExpandTokenSpan(selection_token_span,
+ bounds_sensitive_features->num_tokens_before(),
+ bounds_sensitive_features->num_tokens_after()),
+ {0, tokens.size()});
+
+ std::unique_ptr<CachedFeatures> cached_features;
+ if (!classification_feature_processor_->ExtractFeatures(
+ tokens, extraction_span, embedding_executor_.get(),
+ classification_feature_processor_->EmbeddingSize() +
+ classification_feature_processor_->DenseFeaturesCount(),
+ &cached_features)) {
+ TC_LOG(ERROR) << "Could not extract features.";
+ return {};
+ }
+
+ const std::vector<float> features =
+ cached_features->Get(selection_token_span);
+
+ TensorView<float> logits =
+ classification_executor_->ComputeLogits(TensorView<float>(
+ features.data(), {1, static_cast<int>(features.size())}));
+ if (!logits.is_valid()) {
+ TC_LOG(ERROR) << "Couldn't compute logits.";
+ return {};
+ }
+
+ if (logits.dims() != 2 || logits.dim(0) != 1 ||
+ logits.dim(1) != classification_feature_processor_->NumCollections()) {
+ TC_LOG(ERROR) << "Mismatching output";
+ return {};
+ }
+
+ const std::vector<float> scores =
+ ComputeSoftmax(logits.data(), logits.dim(1));
+
+ std::vector<std::pair<std::string, float>> result(scores.size());
+ for (int i = 0; i < scores.size(); i++) {
+ result[i] = {classification_feature_processor_->LabelToCollection(i),
+ scores[i]};
+ }
+ std::sort(result.begin(), result.end(),
+ [](const std::pair<std::string, float>& a,
+ const std::pair<std::string, float>& b) {
+ return a.second > b.second;
+ });
+
+ // Phone class sanity check.
+ if (result.begin()->first == kPhoneCollection) {
+ const int digit_count = CountDigits(context, selection_indices);
+ if (digit_count <
+ model_->classification_options()->phone_min_num_digits() ||
+ digit_count >
+ model_->classification_options()->phone_max_num_digits()) {
+ return {{kOtherCollection, 1.0}};
+ }
+ }
+
+ return result;
+}
+
+std::vector<AnnotatedSpan> TextClassifier::Annotate(
+ const std::string& context) const {
+ const UnicodeText context_unicode = UTF8ToUnicodeText(context,
+ /*do_copy=*/false);
+
+ std::vector<TokenSpan> chunks;
+ for (const UnicodeTextRange& line :
+ selection_feature_processor_->SplitContext(context_unicode)) {
+ const std::string line_str =
+ UnicodeText::UTF8Substring(line.first, line.second);
+
+ std::vector<Token> tokens;
+ selection_feature_processor_->TokenizeAndFindClick(
+ line_str, {0, std::distance(line.first, line.second)}, &tokens,
+ /*click_pos=*/nullptr);
+ const TokenSpan full_line_span = {0, tokens.size()};
+
+ std::unique_ptr<CachedFeatures> cached_features;
+ if (!classification_feature_processor_->ExtractFeatures(
+ tokens, full_line_span, embedding_executor_.get(),
+ classification_feature_processor_->EmbeddingSize() +
+ classification_feature_processor_->DenseFeaturesCount(),
+ &cached_features)) {
+ TC_LOG(ERROR) << "Could not extract features.";
+ continue;
+ }
+
+ std::vector<TokenSpan> local_chunks;
+ if (!Chunk(tokens.size(), /*span_of_interest=*/full_line_span,
+ *cached_features, &local_chunks)) {
+ TC_LOG(ERROR) << "Could not chunk.";
+ continue;
+ }
+
+ const int offset = std::distance(context_unicode.begin(), line.first);
+ for (const TokenSpan& chunk : local_chunks) {
+ const CodepointSpan codepoint_span =
+ selection_feature_processor_->StripBoundaryCodepoints(
+ line_str, TokenSpanToCodepointSpan(tokens, chunk));
+ chunks.push_back(
+ {codepoint_span.first + offset, codepoint_span.second + offset});
+ }
+ }
+
+ std::vector<AnnotatedSpan> result;
+ for (const CodepointSpan& chunk : chunks) {
+ result.emplace_back();
+ result.back().span = chunk;
+ result.back().classification = ClassifyText(context, chunk);
+ }
+ return result;
+}
+
+bool TextClassifier::Chunk(int num_tokens, const TokenSpan& span_of_interest,
+ const CachedFeatures& cached_features,
+ std::vector<TokenSpan>* chunks) const {
+ const int max_selection_span =
+ selection_feature_processor_->GetOptions()->max_selection_span();
+ const int max_chunk_length = selection_feature_processor_->GetOptions()
+ ->selection_reduced_output_space()
+ ? max_selection_span + 1
+ : 2 * max_selection_span + 1;
+
+ struct ScoredChunk {
+ bool operator<(const ScoredChunk& that) const { return score < that.score; }
+
+ TokenSpan token_span;
+ float score;
+ };
+
+ // The inference span is the span of interest expanded to include
+ // max_selection_span tokens on either side, which is how far a selection can
+ // stretch from the click.
+ const TokenSpan inference_span = IntersectTokenSpans(
+ ExpandTokenSpan(span_of_interest,
+ /*num_tokens_left=*/max_selection_span,
+ /*num_tokens_right=*/max_selection_span),
+ {0, num_tokens});
+
+ std::vector<ScoredChunk> scored_chunks;
+ // Iterate over chunk candidates that:
+ // - Are contained in the inference span
+ // - Have a non-empty intersection with the span of interest
+ // - Are at least one token long
+ // - Are not longer than the maximum chunk length
+ for (int start = inference_span.first; start < span_of_interest.second;
+ ++start) {
+ const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
+ for (int end = leftmost_end_index;
+ end <= inference_span.second && end - start <= max_chunk_length;
+ ++end) {
+ const std::vector<float> features = cached_features.Get({start, end});
+ TensorView<float> logits =
+ selection_executor_->ComputeLogits(TensorView<float>(
+ features.data(), {1, static_cast<int>(features.size())}));
+
+ if (!logits.is_valid()) {
+ TC_LOG(ERROR) << "Couldn't compute logits.";
+ return false;
+ }
+
+ if (logits.dims() != 2 || logits.dim(0) != 1 || logits.dim(1) != 1) {
+ TC_LOG(ERROR) << "Mismatching output";
+ return false;
+ }
+
+ scored_chunks.push_back(ScoredChunk{{start, end}, logits.data()[0]});
+ }
+ }
+
+ std::sort(scored_chunks.rbegin(), scored_chunks.rend());
+
+ // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
+ // them greedily as long as they do not overlap with any previously picked
+ // chunks.
+ std::vector<bool> token_used(TokenSpanSize(inference_span));
+ chunks->clear();
+ for (const ScoredChunk& scored_chunk : scored_chunks) {
+ bool feasible = true;
+ for (int i = scored_chunk.token_span.first;
+ i < scored_chunk.token_span.second; ++i) {
+ if (token_used[i - inference_span.first]) {
+ feasible = false;
+ break;
+ }
+ }
+
+ if (!feasible) {
+ continue;
+ }
+
+ for (int i = scored_chunk.token_span.first;
+ i < scored_chunk.token_span.second; ++i) {
+ token_used[i - inference_span.first] = true;
+ }
+
+ chunks->push_back(scored_chunk.token_span);
+ }
+
+ std::sort(chunks->begin(), chunks->end());
+
+ return true;
+}
+
+const Model* ViewModel(const void* buffer, int size) {
+ if (!buffer) {
+ return nullptr;
+ }
+
+ return LoadAndVerifyModel(buffer, size);
+}
+
+} // namespace libtextclassifier2
diff --git a/text-classifier.h b/text-classifier.h
new file mode 100644
index 0000000..cd84eb4
--- /dev/null
+++ b/text-classifier.h
@@ -0,0 +1,143 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Inference code for the text classification model.
+
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TEXT_CLASSIFIER_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TEXT_CLASSIFIER_H_
+
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "feature-processor.h"
+#include "model-executor.h"
+#include "model_generated.h"
+#include "strip-unpaired-brackets.h"
+#include "types.h"
+#include "util/memory/mmap.h"
+#include "util/utf8/unilib.h"
+
+namespace libtextclassifier2 {
+
+// A text processing model that provides text classification, annotation,
+// and selection suggestion for various types.
+// NOTE: This class is not thread-safe.
+class TextClassifier {
+ public:
+ static std::unique_ptr<TextClassifier> FromUnownedBuffer(const char* buffer,
+ int size);
+ // Takes ownership of the mmap.
+ static std::unique_ptr<TextClassifier> FromScopedMmap(
+ std::unique_ptr<ScopedMmap>* mmap);
+ static std::unique_ptr<TextClassifier> FromFileDescriptor(int fd, int offset,
+ int size);
+ static std::unique_ptr<TextClassifier> FromFileDescriptor(int fd);
+ static std::unique_ptr<TextClassifier> FromPath(const std::string& path);
+
+ // Returns true if the model is ready for use.
+ bool IsInitialized() { return initialized_; }
+
+ // Runs inference given a context and current selection (i.e. index
+ // of the first and one past last selected characters (utf8 codepoint
+ // offsets)). Returns the indices (utf8 codepoint offsets) of the selection
+ // beginning character and one past selection end character.
+ // Returns the original click_indices if an error occurs.
+ // NOTE: The selection indices are passed in and returned in terms of
+ // UTF8 codepoints (not bytes).
+ // Requires that the model is a smart selection model.
+ CodepointSpan SuggestSelection(const std::string& context,
+ CodepointSpan click_indices) const;
+
+ // Classifies the selected text given the context string.
+ // Returns an empty result if an error occurs.
+ std::vector<std::pair<std::string, float>> ClassifyText(
+ const std::string& context, CodepointSpan selection_indices) const;
+
+ // Annotates given input text. The annotations should cover the whole input
+ // context except for whitespaces, and are sorted by their position in the
+ // context string.
+ std::vector<AnnotatedSpan> Annotate(const std::string& context) const;
+
+ protected:
+ // Constructs and initializes text classifier from given model.
+ // Takes ownership of 'mmap', and thus owns the buffer that backs 'model'.
+ TextClassifier(std::unique_ptr<ScopedMmap>* mmap, const Model* model)
+ : model_(model), mmap_(std::move(*mmap)), unilib_(new UniLib()) {
+ ValidateAndInitialize();
+ }
+
+ // Constructs, validates and initializes text classifier from given model.
+ // Does not own the buffer that backs 'model'.
+ explicit TextClassifier(const Model* model)
+ : model_(model), unilib_(new UniLib()) {
+ ValidateAndInitialize();
+ }
+
+ // Checks that model contains all required fields, and initializes internal
+ // datastructures.
+ void ValidateAndInitialize();
+
+ // Initializes regular expressions for the regex model.
+ void InitializeRegexModel();
+
+ // Groups the tokens into chunks. A chunk is a token span that should be the
+ // suggested selection when any of its contained tokens is clicked. The chunks
+ // are non-overlapping and are sorted by their position in the context string.
+ // "num_tokens" is the total number of tokens available (as this method does
+ // not need the actual vector of tokens).
+ // "span_of_interest" is a span of all the tokens that could be clicked.
+ // The resulting chunks all have to overlap with it and they cover this span
+ // completely. The first and last chunk might extend beyond it.
+ // The chunks vector is cleared before filling.
+ bool Chunk(int num_tokens, const TokenSpan& span_of_interest,
+ const CachedFeatures& cached_features,
+ std::vector<TokenSpan>* chunks) const;
+
+ // Collection name for other.
+ const std::string kOtherCollection = "other";
+
+ // Collection name for phone.
+ const std::string kPhoneCollection = "phone";
+
+ const Model* model_;
+
+ std::unique_ptr<ModelExecutor> selection_executor_;
+ std::unique_ptr<ModelExecutor> classification_executor_;
+ std::unique_ptr<EmbeddingExecutor> embedding_executor_;
+
+ std::unique_ptr<FeatureProcessor> selection_feature_processor_;
+ std::unique_ptr<FeatureProcessor> classification_feature_processor_;
+
+ private:
+ struct CompiledRegexPattern {
+ std::string collection_name;
+ std::unique_ptr<UniLib::RegexPattern> pattern;
+ };
+
+ std::unique_ptr<ScopedMmap> mmap_;
+ bool initialized_ = false;
+ std::vector<CompiledRegexPattern> regex_patterns_;
+ std::unique_ptr<UniLib> unilib_;
+};
+
+// Interprets the buffer as a Model flatbuffer and returns it for reading.
+const Model* ViewModel(const void* buffer, int size);
+
+} // namespace libtextclassifier2
+
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TEXT_CLASSIFIER_H_
diff --git a/text-classifier_test.cc b/text-classifier_test.cc
new file mode 100644
index 0000000..82904e5
--- /dev/null
+++ b/text-classifier_test.cc
@@ -0,0 +1,271 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "text-classifier.h"
+
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier2 {
+namespace {
+
+using testing::ElementsAreArray;
+using testing::Pair;
+
+std::string FirstResult(
+ const std::vector<std::pair<std::string, float>>& results) {
+ if (results.empty()) {
+ return "<INVALID RESULTS>";
+ }
+ return results[0].first;
+}
+
+MATCHER_P3(IsAnnotatedSpan, start, end, best_class, "") {
+ return testing::Value(arg.span, Pair(start, end)) &&
+ testing::Value(FirstResult(arg.classification), best_class);
+}
+
+std::string ReadFile(const std::string& file_name) {
+ std::ifstream file_stream(file_name);
+ return std::string(std::istreambuf_iterator<char>(file_stream), {});
+}
+
+std::string GetModelPath() {
+ return LIBTEXTCLASSIFIER_TEST_DATA_DIR;
+}
+
+TEST(TextClassifierTest, EmbeddingExecutorLoadingFails) {
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromPath(GetModelPath() + "wrong_embeddings.fb");
+ EXPECT_FALSE(classifier);
+}
+
+TEST(TextClassifierTest, ClassifyText) {
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromPath(GetModelPath() + "test_model.fb");
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("other",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at", {15, 27})));
+ EXPECT_EQ("other",
+ FirstResult(classifier->ClassifyText("you@android.com", {0, 15})));
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+ "Contact me at you@android.com", {14, 29})));
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "Call me at (800) 123-456 today", {11, 24})));
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+ "Visit www.google.com every today!", {6, 20})));
+
+ // More lines.
+ EXPECT_EQ("other",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at|Visit "
+ "www.google.com every today!|Call me at (800) 123-456 today.",
+ {15, 27})));
+ EXPECT_EQ("other",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at|Visit "
+ "www.google.com every today!|Call me at (800) 123-456 today.",
+ {51, 65})));
+ EXPECT_EQ("phone",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at|Visit "
+ "www.google.com every today!|Call me at (800) 123-456 today.",
+ {90, 103})));
+
+ // Single word.
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5})));
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4})));
+ EXPECT_EQ("<INVALID RESULTS>",
+ FirstResult(classifier->ClassifyText("asdf", {0, 0})));
+
+ // Junk.
+ EXPECT_EQ("<INVALID RESULTS>",
+ FirstResult(classifier->ClassifyText("", {0, 0})));
+ EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
+ "a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})));
+}
+
+TEST(TextClassifierTest, PhoneFiltering) {
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromPath(GetModelPath() + "test_model.fb");
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "phone: (123) 456 789", {7, 20})));
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "phone: (123) 456 789,0001112", {7, 25})));
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+ "phone: (123) 456 789,0001112", {7, 28})));
+}
+
+TEST(TextClassifierTest, SuggestSelection) {
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromPath(GetModelPath() + "test_model.fb");
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(classifier->SuggestSelection(
+ "this afternoon Barack Obama gave a speech at", {15, 21}),
+ std::make_pair(15, 21));
+
+ // Try passing whole string.
+ // If more than 1 token is specified, we should return what was entered.
+ EXPECT_EQ(
+ classifier->SuggestSelection("350 Third Street, Cambridge", {0, 27}),
+ std::make_pair(0, 27));
+
+ // Single letter.
+ EXPECT_EQ(classifier->SuggestSelection("a", {0, 1}), std::make_pair(0, 1));
+
+ // Single word.
+ EXPECT_EQ(classifier->SuggestSelection("asdf", {0, 4}), std::make_pair(0, 4));
+
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+ std::make_pair(11, 23));
+
+ // Unpaired bracket stripping.
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at (857) 225 3556 today", {11, 16}),
+ std::make_pair(11, 25));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at (857 225 3556 today", {11, 15}),
+ std::make_pair(12, 24));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556) today", {11, 14}),
+ std::make_pair(11, 23));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at )857 225 3556( today", {11, 15}),
+ std::make_pair(12, 24));
+
+ // If the resulting selection would be empty, the original span is returned.
+ EXPECT_EQ(classifier->SuggestSelection("call me at )( today", {11, 13}),
+ std::make_pair(11, 13));
+ EXPECT_EQ(classifier->SuggestSelection("call me at ( today", {11, 12}),
+ std::make_pair(11, 12));
+ EXPECT_EQ(classifier->SuggestSelection("call me at ) today", {11, 12}),
+ std::make_pair(11, 12));
+}
+
+TEST(TextClassifierTest, SuggestSelectionsAreSymmetric) {
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromPath(GetModelPath() + "test_model.fb");
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {0, 3}),
+ std::make_pair(0, 27));
+ EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
+ std::make_pair(0, 27));
+ EXPECT_EQ(
+ classifier->SuggestSelection("350 Third Street, Cambridge", {10, 16}),
+ std::make_pair(0, 27));
+ EXPECT_EQ(classifier->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge",
+ {16, 22}),
+ std::make_pair(6, 33));
+}
+
+TEST(TextClassifierTest, SuggestSelectionWithNewLine) {
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromPath(GetModelPath() + "test_model.fb");
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(classifier->SuggestSelection("abc\n857 225 3556", {4, 7}),
+ std::make_pair(4, 16));
+ EXPECT_EQ(classifier->SuggestSelection("857 225 3556\nabc", {0, 3}),
+ std::make_pair(0, 12));
+}
+
+TEST(TextClassifierTest, SuggestSelectionWithPunctuation) {
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromPath(GetModelPath() + "test_model.fb");
+ ASSERT_TRUE(classifier);
+
+ // From the right.
+ EXPECT_EQ(classifier->SuggestSelection(
+ "this afternoon BarackObama, gave a speech at", {15, 26}),
+ std::make_pair(15, 26));
+
+ // From the right multiple.
+ EXPECT_EQ(classifier->SuggestSelection(
+ "this afternoon BarackObama,.,.,, gave a speech at", {15, 26}),
+ std::make_pair(15, 26));
+
+ // From the left multiple.
+ EXPECT_EQ(classifier->SuggestSelection(
+ "this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}),
+ std::make_pair(21, 32));
+
+ // From both sides.
+ EXPECT_EQ(classifier->SuggestSelection(
+ "this afternoon !BarackObama,- gave a speech at", {16, 27}),
+ std::make_pair(16, 27));
+}
+
+TEST(TextClassifierTest, SuggestSelectionNoCrashWithJunk) {
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromPath(GetModelPath() + "test_model.fb");
+ ASSERT_TRUE(classifier);
+
+ // Try passing in bunch of invalid selections.
+ EXPECT_EQ(classifier->SuggestSelection("", {0, 27}), std::make_pair(0, 27));
+ EXPECT_EQ(classifier->SuggestSelection("", {-10, 27}),
+ std::make_pair(-10, 27));
+ EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {0, 27}),
+ std::make_pair(0, 27));
+ EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-30, 300}),
+ std::make_pair(-30, 300));
+ EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-10, -1}),
+ std::make_pair(-10, -1));
+ EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {100, 17}),
+ std::make_pair(100, 17));
+}
+
+TEST(TextClassifierTest, Annotate) {
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromPath(GetModelPath() + "test_model.fb");
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barak Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556.";
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+ IsAnnotatedSpan(0, 0, "<INVALID RESULTS>"),
+ IsAnnotatedSpan(2, 5, "other"),
+ IsAnnotatedSpan(6, 11, "other"),
+ IsAnnotatedSpan(12, 17, "other"),
+ IsAnnotatedSpan(18, 23, "other"),
+ IsAnnotatedSpan(24, 24, "<INVALID RESULTS>"),
+ IsAnnotatedSpan(27, 54, "address"),
+ IsAnnotatedSpan(55, 58, "other"),
+ IsAnnotatedSpan(59, 61, "other"),
+ IsAnnotatedSpan(62, 67, "other"),
+ IsAnnotatedSpan(68, 74, "other"),
+ IsAnnotatedSpan(75, 77, "other"),
+ IsAnnotatedSpan(78, 90, "phone"),
+ }));
+}
+
+// TODO(jacekj): Test the regex functionality.
+
+} // namespace
+} // namespace libtextclassifier2
diff --git a/textclassifier_jni.cc b/textclassifier_jni.cc
index 8740f4c..ecc6500 100644
--- a/textclassifier_jni.cc
+++ b/textclassifier_jni.cc
@@ -14,20 +14,20 @@
* limitations under the License.
*/
-// Simple JNI wrapper for the SmartSelection library.
+// JNI wrapper for the TextClassifier.
#include "textclassifier_jni.h"
#include <jni.h>
#include <vector>
-#include "lang_id/lang-id.h"
-#include "smartselect/text-classification-model.h"
+#include "text-classifier.h"
#include "util/java/scoped_local_ref.h"
+#include "util/memory/mmap.h"
-using libtextclassifier::ModelOptions;
-using libtextclassifier::TextClassificationModel;
-using libtextclassifier::nlp_core::lang_id::LangId;
+using libtextclassifier2::AnnotatedSpan;
+using libtextclassifier2::Model;
+using libtextclassifier2::TextClassifier;
namespace {
@@ -101,17 +101,17 @@
} // namespace
-namespace libtextclassifier {
+namespace libtextclassifier2 {
-using libtextclassifier::CodepointSpan;
+using libtextclassifier2::CodepointSpan;
namespace {
CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
CodepointSpan orig_indices,
bool from_utf8) {
- const libtextclassifier::UnicodeText unicode_str =
- libtextclassifier::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);
+ const libtextclassifier2::UnicodeText unicode_str =
+ libtextclassifier2::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);
int unicode_index = 0;
int bmp_index = 0;
@@ -155,77 +155,78 @@
} // namespace
CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str,
- CodepointSpan orig_indices) {
- return ConvertIndicesBMPUTF8(utf8_str, orig_indices, /*from_utf8=*/false);
+ CodepointSpan bmp_indices) {
+ return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false);
}
CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
- CodepointSpan orig_indices) {
- return ConvertIndicesBMPUTF8(utf8_str, orig_indices, /*from_utf8=*/true);
+ CodepointSpan utf8_indices) {
+ return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true);
}
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-using libtextclassifier::CodepointSpan;
-using libtextclassifier::ConvertIndicesBMPToUTF8;
-using libtextclassifier::ConvertIndicesUTF8ToBMP;
-using libtextclassifier::ScopedLocalRef;
+using libtextclassifier2::CodepointSpan;
+using libtextclassifier2::ConvertIndicesBMPToUTF8;
+using libtextclassifier2::ConvertIndicesUTF8ToBMP;
+using libtextclassifier2::ScopedLocalRef;
-JNI_METHOD(jlong, SmartSelection, nativeNew)
+JNI_METHOD(jlong, TC_CLASS_NAME, nativeNew)
(JNIEnv* env, jobject thiz, jint fd) {
- TextClassificationModel* model = new TextClassificationModel(fd);
- return reinterpret_cast<jlong>(model);
+ return reinterpret_cast<jlong>(
+ TextClassifier::FromFileDescriptor(fd).release());
}
-JNI_METHOD(jlong, SmartSelection, nativeNewFromPath)
+JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromPath)
(JNIEnv* env, jobject thiz, jstring path) {
const std::string path_str = ToStlString(env, path);
- TextClassificationModel* model = new TextClassificationModel(path_str);
- return reinterpret_cast<jlong>(model);
+ return reinterpret_cast<jlong>(TextClassifier::FromPath(path_str).release());
}
-JNI_METHOD(jlong, SmartSelection, nativeNewFromAssetFileDescriptor)
+JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
// Get system-level file descriptor from AssetFileDescriptor.
ScopedLocalRef<jclass> afd_class(
env->FindClass("android/content/res/AssetFileDescriptor"), env);
if (afd_class == nullptr) {
- TC_LOG(ERROR) << "Couln't find AssetFileDescriptor.";
+ TC_LOG(ERROR) << "Couldn't find AssetFileDescriptor.";
return reinterpret_cast<jlong>(nullptr);
}
jmethodID afd_class_getFileDescriptor = env->GetMethodID(
afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;");
if (afd_class_getFileDescriptor == nullptr) {
- TC_LOG(ERROR) << "Couln't find getFileDescriptor.";
+ TC_LOG(ERROR) << "Couldn't find getFileDescriptor.";
return reinterpret_cast<jlong>(nullptr);
}
ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"),
env);
if (fd_class == nullptr) {
- TC_LOG(ERROR) << "Couln't find FileDescriptor.";
+ TC_LOG(ERROR) << "Couldn't find FileDescriptor.";
return reinterpret_cast<jlong>(nullptr);
}
jfieldID fd_class_descriptor =
env->GetFieldID(fd_class.get(), "descriptor", "I");
if (fd_class_descriptor == nullptr) {
- TC_LOG(ERROR) << "Couln't find descriptor.";
+ TC_LOG(ERROR) << "Couldn't find descriptor.";
return reinterpret_cast<jlong>(nullptr);
}
jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor);
jint bundle_cfd = env->GetIntField(bundle_jfd, fd_class_descriptor);
- TextClassificationModel* model =
- new TextClassificationModel(bundle_cfd, offset, size);
- return reinterpret_cast<jlong>(model);
+ return reinterpret_cast<jlong>(
+ TextClassifier::FromFileDescriptor(bundle_cfd, offset, size).release());
}
-JNI_METHOD(jintArray, SmartSelection, nativeSuggest)
+JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggest)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
jint selection_end) {
- TextClassificationModel* model =
- reinterpret_cast<TextClassificationModel*>(ptr);
+ if (!ptr) {
+ return nullptr;
+ }
+
+ TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);
const std::string context_utf8 = ToStlString(env, context);
CodepointSpan input_indices =
@@ -240,39 +241,42 @@
return result;
}
-JNI_METHOD(jobjectArray, SmartSelection, nativeClassifyText)
+JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeClassifyText)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end, jint input_flags) {
- TextClassificationModel* ff_model =
- reinterpret_cast<TextClassificationModel*>(ptr);
+ jint selection_end) {
+ if (!ptr) {
+ return nullptr;
+ }
+ TextClassifier* ff_model = reinterpret_cast<TextClassifier*>(ptr);
const std::vector<std::pair<std::string, float>> classification_result =
ff_model->ClassifyText(ToStlString(env, context),
- {selection_begin, selection_end}, input_flags);
+ {selection_begin, selection_end});
return ScoredStringsToJObjectArray(
- env, TC_PACKAGE_PATH "SmartSelection$ClassificationResult",
+ env, TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult",
classification_result);
}
-JNI_METHOD(jobjectArray, SmartSelection, nativeAnnotate)
+JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeAnnotate)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context) {
- TextClassificationModel* model =
- reinterpret_cast<TextClassificationModel*>(ptr);
+ if (!ptr) {
+ return nullptr;
+ }
+ TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);
std::string context_utf8 = ToStlString(env, context);
- std::vector<TextClassificationModel::AnnotatedSpan> annotations =
- model->Annotate(context_utf8);
+ std::vector<AnnotatedSpan> annotations = model->Annotate(context_utf8);
jclass result_class =
- env->FindClass(TC_PACKAGE_PATH "SmartSelection$AnnotatedSpan");
+ env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan");
if (!result_class) {
TC_LOG(ERROR) << "Couldn't find result class: "
- << TC_PACKAGE_PATH "SmartSelection$AnnotatedSpan";
+ << TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan";
return nullptr;
}
jmethodID result_class_constructor = env->GetMethodID(
result_class, "<init>",
- "(II[L" TC_PACKAGE_PATH "SmartSelection$ClassificationResult;)V");
+ "(II[L" TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult;)V");
jobjectArray results =
env->NewObjectArray(annotations.size(), result_class, nullptr);
@@ -284,7 +288,7 @@
result_class, result_class_constructor,
static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second),
ScoredStringsToJObjectArray(
- env, TC_PACKAGE_PATH "SmartSelection$ClassificationResult",
+ env, TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult",
annotations[i].classification));
env->SetObjectArrayElement(results, i, result);
env->DeleteLocalRef(result);
@@ -293,58 +297,38 @@
return results;
}
-JNI_METHOD(void, SmartSelection, nativeClose)
+JNI_METHOD(void, TC_CLASS_NAME, nativeClose)
(JNIEnv* env, jobject thiz, jlong ptr) {
- TextClassificationModel* model =
- reinterpret_cast<TextClassificationModel*>(ptr);
+ TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);
delete model;
}
-JNI_METHOD(jstring, SmartSelection, nativeGetLanguage)
+JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLanguage)
(JNIEnv* env, jobject clazz, jint fd) {
- ModelOptions model_options;
- if (ReadSelectionModelOptions(fd, &model_options)) {
- return env->NewStringUTF(model_options.language().c_str());
- } else {
- return env->NewStringUTF("UNK");
+ std::unique_ptr<libtextclassifier2::ScopedMmap> mmap(
+ new libtextclassifier2::ScopedMmap(fd));
+ if (!mmap->handle().ok()) {
+ return env->NewStringUTF("");
}
-}
-
-JNI_METHOD(jint, SmartSelection, nativeGetVersion)
-(JNIEnv* env, jobject clazz, jint fd) {
- ModelOptions model_options;
- if (ReadSelectionModelOptions(fd, &model_options)) {
- return model_options.version();
- } else {
- return -1;
+ const Model* model = libtextclassifier2::ViewModel(
+ mmap->handle().start(), mmap->handle().num_bytes());
+ if (!model || !model->language()) {
+ return env->NewStringUTF("");
}
+ return env->NewStringUTF(model->language()->c_str());
}
-#ifndef LIBTEXTCLASSIFIER_DISABLE_LANG_ID
-JNI_METHOD(jlong, LangId, nativeNew)
-(JNIEnv* env, jobject thiz, jint fd) {
- return reinterpret_cast<jlong>(new LangId(fd));
-}
-
-JNI_METHOD(jobjectArray, LangId, nativeFindLanguages)
-(JNIEnv* env, jobject thiz, jlong ptr, jstring text) {
- LangId* lang_id = reinterpret_cast<LangId*>(ptr);
- const std::vector<std::pair<std::string, float>> scored_languages =
- lang_id->FindLanguages(ToStlString(env, text));
-
- return ScoredStringsToJObjectArray(
- env, TC_PACKAGE_PATH "LangId$ClassificationResult", scored_languages);
-}
-
-JNI_METHOD(void, LangId, nativeClose)
-(JNIEnv* env, jobject thiz, jlong ptr) {
- LangId* lang_id = reinterpret_cast<LangId*>(ptr);
- delete lang_id;
-}
-
-JNI_METHOD(int, LangId, nativeGetVersion)
+JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd) {
- std::unique_ptr<LangId> lang_id(new LangId(fd));
- return lang_id->version();
+ std::unique_ptr<libtextclassifier2::ScopedMmap> mmap(
+ new libtextclassifier2::ScopedMmap(fd));
+ if (!mmap->handle().ok()) {
+ return 0;
+ }
+ const Model* model = libtextclassifier2::ViewModel(
+ mmap->handle().start(), mmap->handle().num_bytes());
+ if (!model) {
+ return 0;
+ }
+ return model->version();
}
-#endif
diff --git a/textclassifier_jni.h b/textclassifier_jni.h
index 1709ff4..1f64fff 100644
--- a/textclassifier_jni.h
+++ b/textclassifier_jni.h
@@ -14,17 +14,28 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_TEXTCLASSIFIER_JNI_H_
-#define LIBTEXTCLASSIFIER_TEXTCLASSIFIER_JNI_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TEXTCLASSIFIER_JNI_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TEXTCLASSIFIER_JNI_H_
#include <jni.h>
#include <string>
-#include "smartselect/types.h"
+#include "types.h"
+
+// When we use a macro as an argument for a macro, an additional level of
+// indirection is needed, if the macro argument is used with # or ##.
+#define ADD_QUOTES_HELPER(TOKEN) #TOKEN
+#define ADD_QUOTES(TOKEN) ADD_QUOTES_HELPER(TOKEN)
#ifndef TC_PACKAGE_NAME
#define TC_PACKAGE_NAME android_view_textclassifier
#endif
+
+#ifndef TC_CLASS_NAME
+#define TC_CLASS_NAME SmartSelection
+#endif
+#define TC_CLASS_NAME_STR ADD_QUOTES(TC_CLASS_NAME)
+
#ifndef TC_PACKAGE_PATH
#define TC_PACKAGE_PATH "android/view/textclassifier/"
#endif
@@ -35,6 +46,7 @@
Java_##package_name##_##class_name##_##method_name
// The indirection is needed to correctly expand the TC_PACKAGE_NAME macro.
+// See the explanation near ADD_QUOTES macro.
#define JNI_METHOD2(return_type, package_name, class_name, method_name) \
JNI_METHOD_PRIMITIVE(return_type, package_name, class_name, method_name)
@@ -46,63 +58,52 @@
#endif
// SmartSelection.
-JNI_METHOD(jlong, SmartSelection, nativeNew)
+JNI_METHOD(jlong, TC_CLASS_NAME, nativeNew)
(JNIEnv* env, jobject thiz, jint fd);
-JNI_METHOD(jlong, SmartSelection, nativeNewFromPath)
+JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromPath)
(JNIEnv* env, jobject thiz, jstring path);
-JNI_METHOD(jlong, SmartSelection, nativeNewFromAssetFileDescriptor)
+JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);
-JNI_METHOD(jintArray, SmartSelection, nativeSuggest)
+JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggest)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
jint selection_end);
-JNI_METHOD(jobjectArray, SmartSelection, nativeClassifyText)
+JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeClassifyText)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end, jint input_flags);
+ jint selection_end);
-JNI_METHOD(jobjectArray, SmartSelection, nativeAnnotate)
+JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeAnnotate)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context);
-JNI_METHOD(void, SmartSelection, nativeClose)
+JNI_METHOD(void, TC_CLASS_NAME, nativeClose)
(JNIEnv* env, jobject thiz, jlong ptr);
-JNI_METHOD(jstring, SmartSelection, nativeGetLanguage)
+JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLanguage)
(JNIEnv* env, jobject clazz, jint fd);
-JNI_METHOD(jint, SmartSelection, nativeGetVersion)
+JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd);
-#ifndef LIBTEXTCLASSIFIER_DISABLE_LANG_ID
-// LangId.
-JNI_METHOD(jlong, LangId, nativeNew)(JNIEnv* env, jobject thiz, jint fd);
-
-JNI_METHOD(jobjectArray, LangId, nativeFindLanguages)
-(JNIEnv* env, jobject thiz, jlong ptr, jstring text);
-
-JNI_METHOD(void, LangId, nativeClose)(JNIEnv* env, jobject thiz, jlong ptr);
-
-JNI_METHOD(int, LangId, nativeGetVersion)(JNIEnv* env, jobject clazz, jint fd);
-#endif
-
#ifdef __cplusplus
}
#endif
-namespace libtextclassifier {
+namespace libtextclassifier2 {
// Given a utf8 string and a span expressed in Java BMP (basic multilingual
// plane) codepoints, converts it to a span expressed in utf8 codepoints.
-libtextclassifier::CodepointSpan ConvertIndicesBMPToUTF8(
- const std::string& utf8_str, libtextclassifier::CodepointSpan bmp_indices);
+libtextclassifier2::CodepointSpan ConvertIndicesBMPToUTF8(
+ const std::string& utf8_str, libtextclassifier2::CodepointSpan bmp_indices);
// Given a utf8 string and a span expressed in utf8 codepoints, converts it to a
// span expressed in Java BMP (basic multilingual plane) codepoints.
-libtextclassifier::CodepointSpan ConvertIndicesUTF8ToBMP(
- const std::string& utf8_str, libtextclassifier::CodepointSpan utf8_indices);
+libtextclassifier2::CodepointSpan ConvertIndicesUTF8ToBMP(
+ const std::string& utf8_str,
+ libtextclassifier2::CodepointSpan utf8_indices);
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_TEXTCLASSIFIER_JNI_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TEXTCLASSIFIER_JNI_H_
diff --git a/textclassifier_jni_test.cc b/textclassifier_jni_test.cc
index ffc193b..87b96fa 100644
--- a/textclassifier_jni_test.cc
+++ b/textclassifier_jni_test.cc
@@ -19,7 +19,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
namespace {
TEST(TextClassifier, ConvertIndicesBMPUTF8) {
@@ -76,4 +76,4 @@
}
} // namespace
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
diff --git a/smartselect/token-feature-extractor.cc b/token-feature-extractor.cc
similarity index 74%
rename from smartselect/token-feature-extractor.cc
rename to token-feature-extractor.cc
index 6afd951..33c4d75 100644
--- a/smartselect/token-feature-extractor.cc
+++ b/token-feature-extractor.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "smartselect/token-feature-extractor.h"
+#include "token-feature-extractor.h"
#include <cctype>
#include <string>
@@ -23,12 +23,8 @@
#include "util/hash/farmhash.h"
#include "util/strings/stringpiece.h"
#include "util/utf8/unicodetext.h"
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
-#include "unicode/regex.h"
-#include "unicode/uchar.h"
-#endif
-namespace libtextclassifier {
+namespace libtextclassifier2 {
namespace {
@@ -50,69 +46,41 @@
return copy;
}
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
void RemapTokenUnicode(const std::string& token,
const TokenFeatureExtractorOptions& options,
- UnicodeText* remapped) {
+ const UniLib& unilib, UnicodeText* remapped) {
if (!options.remap_digits && !options.lowercase_tokens) {
// Leave remapped untouched.
return;
}
UnicodeText word = UTF8ToUnicodeText(token, /*do_copy=*/false);
- icu::UnicodeString icu_string;
+ remapped->clear();
for (auto it = word.begin(); it != word.end(); ++it) {
- if (options.remap_digits && u_isdigit(*it)) {
- icu_string.append('0');
+ if (options.remap_digits && unilib.IsDigit(*it)) {
+ remapped->AppendCodepoint('0');
} else if (options.lowercase_tokens) {
- icu_string.append(u_tolower(*it));
+ remapped->AppendCodepoint(unilib.ToLower(*it));
} else {
- icu_string.append(*it);
+ remapped->AppendCodepoint(*it);
}
}
- std::string utf8_str;
- icu_string.toUTF8String(utf8_str);
- remapped->CopyUTF8(utf8_str.data(), utf8_str.length());
}
-#endif
} // namespace
TokenFeatureExtractor::TokenFeatureExtractor(
- const TokenFeatureExtractorOptions& options)
- : options_(options) {
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
- UErrorCode status;
+ const TokenFeatureExtractorOptions& options, const UniLib& unilib)
+ : options_(options), unilib_(unilib) {
for (const std::string& pattern : options.regexp_features) {
- status = U_ZERO_ERROR;
- regex_patterns_.push_back(
- std::unique_ptr<icu::RegexPattern>(icu::RegexPattern::compile(
- icu::UnicodeString(pattern.c_str(), pattern.size(), "utf-8"), 0,
- status)));
- if (U_FAILURE(status)) {
- TC_LOG(WARNING) << "Failed to load pattern" << pattern;
- }
+ regex_patterns_.push_back(std::unique_ptr<UniLib::RegexPattern>(
+ unilib_.CreateRegexPattern(pattern)));
}
-#else
- bool found_unsupported_regexp_features = false;
- for (const std::string& pattern : options.regexp_features) {
- // A temporary solution to support this specific regexp pattern without
- // adding too much binary size.
- if (pattern == "^[^a-z]*$") {
- enable_all_caps_feature_ = true;
- } else {
- found_unsupported_regexp_features = true;
- }
- }
- if (found_unsupported_regexp_features) {
- TC_LOG(WARNING) << "ICU not supported regexp features ignored.";
- }
-#endif
}
int TokenFeatureExtractor::HashToken(StringPiece token) const {
if (options_.allowed_chargrams.empty()) {
- return tcfarmhash::Fingerprint64(token) % options_.num_buckets;
+ return tc2farmhash::Fingerprint64(token) % options_.num_buckets;
} else {
// Padding and out-of-vocabulary tokens have extra buckets reserved because
// they are special and important tokens, and we don't want them to share
@@ -126,7 +94,7 @@
options_.allowed_chargrams.end()) {
return 0; // Out-of-vocabulary.
} else {
- return (tcfarmhash::Fingerprint64(token) %
+ return (tc2farmhash::Fingerprint64(token) %
(options_.num_buckets - kNumExtraBuckets)) +
kNumExtraBuckets;
}
@@ -192,13 +160,12 @@
std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode(
const Token& token) const {
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
std::vector<int> result;
if (token.is_padding || token.value.empty()) {
result.push_back(HashToken("<PAD>"));
} else {
UnicodeText word = UTF8ToUnicodeText(token.value, /*do_copy=*/false);
- RemapTokenUnicode(token.value, options_, &word);
+ RemapTokenUnicode(token.value, options_, unilib_, &word);
// Trim the word if needed by finding a left-cut point and right-cut point.
auto left_cut = word.begin();
@@ -268,10 +235,6 @@
}
}
return result;
-#else
- TC_LOG(WARNING) << "ICU not supported. No feature extracted.";
- return {};
-#endif
}
bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
@@ -287,13 +250,7 @@
if (options_.unicode_aware_features) {
UnicodeText token_unicode =
UTF8ToUnicodeText(token.value, /*do_copy=*/false);
- bool is_upper;
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
- is_upper = u_isupper(*token_unicode.begin());
-#else
- TC_LOG(WARNING) << "Using non-unicode isupper because ICU is disabled.";
- is_upper = isupper(*token_unicode.begin());
-#endif
+ const bool is_upper = unilib_.IsUpper(*token_unicode.begin());
if (!token.value.empty() && is_upper) {
dense_features->push_back(1.0);
} else {
@@ -320,46 +277,23 @@
}
}
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
// Add regexp features.
if (!regex_patterns_.empty()) {
- icu::UnicodeString unicode_str(token.value.c_str(), token.value.size(),
- "utf-8");
for (int i = 0; i < regex_patterns_.size(); ++i) {
if (!regex_patterns_[i].get()) {
dense_features->push_back(-1.0);
continue;
}
- // Check for match.
- UErrorCode status = U_ZERO_ERROR;
- std::unique_ptr<icu::RegexMatcher> matcher(
- regex_patterns_[i]->matcher(unicode_str, status));
- if (matcher->find()) {
+ if (regex_patterns_[i]->Matches(token.value)) {
dense_features->push_back(1.0);
} else {
dense_features->push_back(-1.0);
}
}
}
-#else
- if (enable_all_caps_feature_) {
- bool is_all_caps = true;
- for (const char character_byte : token.value) {
- if (islower(character_byte)) {
- is_all_caps = false;
- break;
- }
- }
- if (is_all_caps) {
- dense_features->push_back(1.0);
- } else {
- dense_features->push_back(-1.0);
- }
- }
-#endif
return true;
}
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
diff --git a/smartselect/token-feature-extractor.h b/token-feature-extractor.h
similarity index 81%
rename from smartselect/token-feature-extractor.h
rename to token-feature-extractor.h
index 5afeca4..9d476ba 100644
--- a/smartselect/token-feature-extractor.h
+++ b/token-feature-extractor.h
@@ -14,20 +14,18 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TOKEN_FEATURE_EXTRACTOR_H_
-#define LIBTEXTCLASSIFIER_SMARTSELECT_TOKEN_FEATURE_EXTRACTOR_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TOKEN_FEATURE_EXTRACTOR_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TOKEN_FEATURE_EXTRACTOR_H_
#include <memory>
#include <unordered_set>
#include <vector>
-#include "smartselect/types.h"
+#include "types.h"
#include "util/strings/stringpiece.h"
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
-#include "unicode/regex.h"
-#endif
+#include "util/utf8/unilib.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
struct TokenFeatureExtractorOptions {
// Number of buckets used for hashing charactergrams.
@@ -67,7 +65,8 @@
class TokenFeatureExtractor {
public:
- explicit TokenFeatureExtractor(const TokenFeatureExtractorOptions& options);
+ TokenFeatureExtractor(const TokenFeatureExtractorOptions& options,
+ const UniLib& unilib);
// Extracts features from a token.
// - is_in_span is a bool indicator whether the token is a part of the
@@ -83,13 +82,7 @@
int DenseFeaturesCount() const {
int feature_count =
options_.extract_case_feature + options_.extract_selection_mask_feature;
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
feature_count += regex_patterns_.size();
-#else
- if (enable_all_caps_feature_) {
- feature_count += 1;
- }
-#endif
return feature_count;
}
@@ -110,13 +103,10 @@
private:
TokenFeatureExtractorOptions options_;
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
- std::vector<std::unique_ptr<icu::RegexPattern>> regex_patterns_;
-#else
- bool enable_all_caps_feature_ = false;
-#endif
+ std::vector<std::unique_ptr<UniLib::RegexPattern>> regex_patterns_;
+ const UniLib& unilib_;
};
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_SMARTSELECT_TOKEN_FEATURE_EXTRACTOR_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TOKEN_FEATURE_EXTRACTOR_H_
diff --git a/smartselect/token-feature-extractor_test.cc b/token-feature-extractor_test.cc
similarity index 92%
rename from smartselect/token-feature-extractor_test.cc
rename to token-feature-extractor_test.cc
index 4b635fd..d6e48bb 100644
--- a/smartselect/token-feature-extractor_test.cc
+++ b/token-feature-extractor_test.cc
@@ -14,18 +14,18 @@
* limitations under the License.
*/
-#include "smartselect/token-feature-extractor.h"
+#include "token-feature-extractor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
namespace {
class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
public:
- using TokenFeatureExtractor::TokenFeatureExtractor;
using TokenFeatureExtractor::HashToken;
+ using TokenFeatureExtractor::TokenFeatureExtractor;
};
TEST(TokenFeatureExtractorTest, ExtractAscii) {
@@ -35,7 +35,8 @@
options.extract_case_feature = true;
options.unicode_aware_features = false;
options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor(options);
+ const UniLib unilib;
+ TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
std::vector<float> dense_features;
@@ -105,7 +106,8 @@
options.extract_case_feature = true;
options.unicode_aware_features = false;
options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor(options);
+ const UniLib unilib;
+ TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
std::vector<float> dense_features;
@@ -134,7 +136,8 @@
options.extract_case_feature = true;
options.unicode_aware_features = true;
options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor(options);
+ const UniLib unilib;
+ TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
std::vector<float> dense_features;
@@ -204,7 +207,8 @@
options.extract_case_feature = true;
options.unicode_aware_features = true;
options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor(options);
+ const UniLib unilib;
+ TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
std::vector<float> dense_features;
@@ -227,6 +231,7 @@
EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
}
+#ifdef LIBTEXTCLASSIFIER_TEST_ICU
TEST(TokenFeatureExtractorTest, ICUCaseFeature) {
TokenFeatureExtractorOptions options;
options.num_buckets = 1000;
@@ -234,7 +239,8 @@
options.extract_case_feature = true;
options.unicode_aware_features = true;
options.extract_selection_mask_feature = false;
- TokenFeatureExtractor extractor(options);
+ const UniLib unilib;
+ TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
std::vector<float> dense_features;
@@ -260,6 +266,7 @@
&dense_features);
EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
}
+#endif
TEST(TokenFeatureExtractorTest, DigitRemapping) {
TokenFeatureExtractorOptions options;
@@ -267,7 +274,8 @@
options.chargram_orders = std::vector<int>{1, 2};
options.remap_digits = true;
options.unicode_aware_features = false;
- TokenFeatureExtractor extractor(options);
+ const UniLib unilib;
+ TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
std::vector<float> dense_features;
@@ -291,7 +299,8 @@
options.chargram_orders = std::vector<int>{1, 2};
options.remap_digits = true;
options.unicode_aware_features = true;
- TokenFeatureExtractor extractor(options);
+ const UniLib unilib;
+ TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
std::vector<float> dense_features;
@@ -315,7 +324,8 @@
options.chargram_orders = std::vector<int>{1, 2};
options.lowercase_tokens = true;
options.unicode_aware_features = false;
- TokenFeatureExtractor extractor(options);
+ const UniLib unilib;
+ TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
std::vector<float> dense_features;
@@ -332,13 +342,15 @@
EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
}
+#ifdef LIBTEXTCLASSIFIER_TEST_ICU
TEST(TokenFeatureExtractorTest, LowercaseUnicode) {
TokenFeatureExtractorOptions options;
options.num_buckets = 1000;
options.chargram_orders = std::vector<int>{1, 2};
options.lowercase_tokens = true;
options.unicode_aware_features = true;
- TokenFeatureExtractor extractor(options);
+ const UniLib unilib;
+ TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
std::vector<float> dense_features;
@@ -349,7 +361,9 @@
&dense_features);
EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
}
+#endif
+#ifdef LIBTEXTCLASSIFIER_TEST_ICU
TEST(TokenFeatureExtractorTest, RegexFeatures) {
TokenFeatureExtractorOptions options;
options.num_buckets = 1000;
@@ -358,7 +372,8 @@
options.unicode_aware_features = false;
options.regexp_features.push_back("^[a-z]+$"); // all lower case.
options.regexp_features.push_back("^[0-9]+$"); // all digits.
- TokenFeatureExtractor extractor(options);
+ const UniLib unilib;
+ TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
std::vector<float> dense_features;
@@ -381,6 +396,7 @@
&dense_features);
EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0}));
}
+#endif
TEST(TokenFeatureExtractorTest, ExtractTooLongWord) {
TokenFeatureExtractorOptions options;
@@ -389,7 +405,8 @@
options.extract_case_feature = true;
options.unicode_aware_features = true;
options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor(options);
+ const UniLib unilib;
+ TestingTokenFeatureExtractor extractor(options, unilib);
// Test that this runs. ASAN should catch problems.
std::vector<int> sparse_features;
@@ -413,10 +430,12 @@
options.extract_case_feature = true;
options.unicode_aware_features = true;
options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor_unicode(options);
+
+ UniLib unilib;
+ TestingTokenFeatureExtractor extractor_unicode(options, unilib);
options.unicode_aware_features = false;
- TestingTokenFeatureExtractor extractor_ascii(options);
+ TestingTokenFeatureExtractor extractor_ascii(options, unilib);
for (const std::string& input :
{"https://www.abcdefgh.com/in/xxxkkkvayio",
@@ -447,7 +466,8 @@
options.unicode_aware_features = false;
options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor(options);
+ const UniLib unilib;
+ TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
std::vector<float> dense_features;
@@ -473,7 +493,8 @@
options.allowed_chargrams.insert("!");
options.allowed_chargrams.insert("\xc4"); // UTF8 control character.
- TestingTokenFeatureExtractor extractor(options);
+ const UniLib unilib;
+ TestingTokenFeatureExtractor extractor(options, unilib);
std::vector<int> sparse_features;
std::vector<float> dense_features;
@@ -540,4 +561,4 @@
}
} // namespace
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
diff --git a/tokenizer.cc b/tokenizer.cc
new file mode 100644
index 0000000..456826d
--- /dev/null
+++ b/tokenizer.cc
@@ -0,0 +1,120 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tokenizer.h"
+
+#include <algorithm>
+
+#include "util/base/logging.h"
+#include "util/strings/utf8.h"
+#include "util/utf8/unicodetext.h"
+
+namespace libtextclassifier2 {
+
+Tokenizer::Tokenizer(
+ const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
+ bool split_on_script_change)
+ : codepoint_ranges_(codepoint_ranges),
+ split_on_script_change_(split_on_script_change) {
+ std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
+ [](const TokenizationCodepointRange* a,
+ const TokenizationCodepointRange* b) {
+ return a->start() < b->start();
+ });
+}
+
+const TokenizationCodepointRange* Tokenizer::FindTokenizationRange(
+ int codepoint) const {
+ auto it = std::lower_bound(
+ codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint,
+ [](const TokenizationCodepointRange* range, int codepoint) {
+ // This function compares range with the codepoint for the purpose of
+ // finding the first greater or equal range. Because of the use of
+ // std::lower_bound it needs to return true when range < codepoint;
+ // the first time it will return false the lower bound is found and
+ // returned.
+ //
+ // It might seem weird that the condition is range.end <= codepoint
+ // here but when codepoint == range.end it means it's actually just
+ // outside of the range, thus the range is less than the codepoint.
+ return range->end() <= codepoint;
+ });
+ if (it != codepoint_ranges_.end() && (*it)->start() <= codepoint &&
+ (*it)->end() > codepoint) {
+ return *it;
+ } else {
+ return nullptr;
+ }
+}
+
+void Tokenizer::GetScriptAndRole(char32 codepoint,
+ TokenizationCodepointRange_::Role* role,
+ int* script) const {
+ const TokenizationCodepointRange* range = FindTokenizationRange(codepoint);
+ if (range) {
+ *role = range->role();
+ *script = range->script_id();
+ } else {
+ *role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+ *script = kUnknownScript;
+ }
+}
+
+std::vector<Token> Tokenizer::Tokenize(const std::string& utf8_text) const {
+ UnicodeText context_unicode = UTF8ToUnicodeText(utf8_text, /*do_copy=*/false);
+
+ std::vector<Token> result;
+ Token new_token("", 0, 0);
+ int codepoint_index = 0;
+
+ int last_script = kInvalidScript;
+ for (auto it = context_unicode.begin(); it != context_unicode.end();
+ ++it, ++codepoint_index) {
+ TokenizationCodepointRange_::Role role;
+ int script;
+ GetScriptAndRole(*it, &role, &script);
+
+ if (role & TokenizationCodepointRange_::Role_SPLIT_BEFORE ||
+ (split_on_script_change_ && last_script != kInvalidScript &&
+ last_script != script)) {
+ if (!new_token.value.empty()) {
+ result.push_back(new_token);
+ }
+ new_token = Token("", codepoint_index, codepoint_index);
+ }
+ if (!(role & TokenizationCodepointRange_::Role_DISCARD_CODEPOINT)) {
+ new_token.value += std::string(
+ it.utf8_data(),
+ it.utf8_data() + GetNumBytesForNonZeroUTF8Char(it.utf8_data()));
+ ++new_token.end;
+ }
+ if (role & TokenizationCodepointRange_::Role_SPLIT_AFTER) {
+ if (!new_token.value.empty()) {
+ result.push_back(new_token);
+ }
+ new_token = Token("", codepoint_index + 1, codepoint_index + 1);
+ }
+
+ last_script = script;
+ }
+ if (!new_token.value.empty()) {
+ result.push_back(new_token);
+ }
+
+ return result;
+}
+
+} // namespace libtextclassifier2
diff --git a/tokenizer.h b/tokenizer.h
new file mode 100644
index 0000000..72a9fbd
--- /dev/null
+++ b/tokenizer.h
@@ -0,0 +1,66 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TOKENIZER_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TOKENIZER_H_
+
+#include <string>
+#include <vector>
+
+#include "model_generated.h"
+#include "types.h"
+#include "util/base/integral_types.h"
+
+namespace libtextclassifier2 {
+
+const int kInvalidScript = -1;
+const int kUnknownScript = -2;
+
+// Tokenizer splits the input string into a sequence of tokens, according to the
+// configuration.
+class Tokenizer {
+ public:
+ explicit Tokenizer(
+ const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
+ bool split_on_script_change);
+
+ // Tokenizes the input string using the selected tokenization method.
+ std::vector<Token> Tokenize(const std::string& utf8_text) const;
+
+ protected:
+ // Finds the tokenization codepoint range config for given codepoint.
+ // Internally uses binary search so should be O(log(# of codepoint_ranges)).
+ const TokenizationCodepointRange* FindTokenizationRange(int codepoint) const;
+
+ // Finds the role and script for given codepoint. If not found, DEFAULT_ROLE
+ // and kUnknownScript are assigned.
+ void GetScriptAndRole(char32 codepoint,
+ TokenizationCodepointRange_::Role* role,
+ int* script) const;
+
+ private:
+ // Codepoint ranges that determine how different codepoints are tokenized.
+ // The ranges must not overlap.
+ std::vector<const TokenizationCodepointRange*> codepoint_ranges_;
+
+ // If true, tokens will be additionally split when the codepoint's script_id
+ // changes.
+ bool split_on_script_change_;
+};
+
+} // namespace libtextclassifier2
+
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TOKENIZER_H_
diff --git a/tokenizer_test.cc b/tokenizer_test.cc
new file mode 100644
index 0000000..d9a0dea
--- /dev/null
+++ b/tokenizer_test.cc
@@ -0,0 +1,334 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tokenizer.h"
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier2 {
+namespace {
+
+using testing::ElementsAreArray;
+
+class TestingTokenizer : public Tokenizer {
+ public:
+ explicit TestingTokenizer(
+ const std::vector<const TokenizationCodepointRange*>&
+ codepoint_range_configs,
+ bool split_on_script_change)
+ : Tokenizer(codepoint_range_configs, split_on_script_change) {}
+
+ using Tokenizer::FindTokenizationRange;
+};
+
+class TestingTokenizerProxy {
+ public:
+ explicit TestingTokenizerProxy(
+ const std::vector<TokenizationCodepointRangeT>& codepoint_range_configs,
+ bool split_on_script_change) {
+ int num_configs = codepoint_range_configs.size();
+ std::vector<const TokenizationCodepointRange*> configs_fb;
+ buffers_.reserve(num_configs);
+ for (int i = 0; i < num_configs; i++) {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(CreateTokenizationCodepointRange(
+ builder, &codepoint_range_configs[i]));
+ buffers_.push_back(builder.Release());
+ configs_fb.push_back(
+ flatbuffers::GetRoot<TokenizationCodepointRange>(buffers_[i].data()));
+ }
+ tokenizer_ = std::unique_ptr<TestingTokenizer>(
+ new TestingTokenizer(configs_fb, split_on_script_change));
+ }
+
+ TokenizationCodepointRange_::Role TestFindTokenizationRole(int c) const {
+ const TokenizationCodepointRange* range =
+ tokenizer_->FindTokenizationRange(c);
+ if (range != nullptr) {
+ return range->role();
+ } else {
+ return TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+ }
+ }
+
+ std::vector<Token> Tokenize(const std::string& utf8_text) const {
+ return tokenizer_->Tokenize(utf8_text);
+ }
+
+ private:
+ std::vector<flatbuffers::DetachedBuffer> buffers_;
+ std::unique_ptr<TestingTokenizer> tokenizer_;
+};
+
+TEST(TokenizerTest, FindTokenizationRange) {
+ std::vector<TokenizationCodepointRangeT> configs;
+ TokenizationCodepointRangeT* config;
+
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0;
+ config->end = 10;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 1234;
+ config->end = 12345;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+
+ TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/false);
+
+ // Test hits to the first group.
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(0),
+ TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(5),
+ TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(10),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+
+ // Test a hit to the second group.
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(31),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(32),
+ TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(33),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+
+ // Test hits to the third group.
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(1233),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(1234),
+ TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(12344),
+ TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(12345),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+
+ // Test a hit outside.
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(99),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+}
+
+TEST(TokenizerTest, TokenizeOnSpace) {
+ std::vector<TokenizationCodepointRangeT> configs;
+ TokenizationCodepointRangeT* config;
+
+ configs.emplace_back();
+ config = &configs.back();
+ // Space character.
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+ TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/false);
+ std::vector<Token> tokens = tokenizer.Tokenize("Hello world!");
+
+ EXPECT_THAT(tokens,
+ ElementsAreArray({Token("Hello", 0, 5), Token("world!", 6, 12)}));
+}
+
+TEST(TokenizerTest, TokenizeOnSpaceAndScriptChange) {
+ std::vector<TokenizationCodepointRangeT> configs;
+ TokenizationCodepointRangeT* config;
+
+ // Latin.
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0;
+ config->end = 32;
+ config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+ config->script_id = 1;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+ config->script_id = 1;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 33;
+ config->end = 0x77F + 1;
+ config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+ config->script_id = 1;
+
+ TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/true);
+ EXPECT_THAT(tokenizer.Tokenize("앨라배마 주 전화(123) 456-789웹사이트"),
+ std::vector<Token>({Token("앨라배마", 0, 4), Token("주", 5, 6),
+ Token("전화", 7, 10), Token("(123)", 10, 15),
+ Token("456-789", 16, 23),
+ Token("웹사이트", 23, 28)}));
+} // namespace
+
+TEST(TokenizerTest, TokenizeComplex) {
+ std::vector<TokenizationCodepointRangeT> configs;
+ TokenizationCodepointRangeT* config;
+
+ // Source: http://www.unicode.org/Public/10.0.0/ucd/Blocks-10.0.0d1.txt
+ // Latin - cyrilic.
+ // 0000..007F; Basic Latin
+ // 0080..00FF; Latin-1 Supplement
+ // 0100..017F; Latin Extended-A
+ // 0180..024F; Latin Extended-B
+ // 0250..02AF; IPA Extensions
+ // 02B0..02FF; Spacing Modifier Letters
+ // 0300..036F; Combining Diacritical Marks
+ // 0370..03FF; Greek and Coptic
+ // 0400..04FF; Cyrillic
+ // 0500..052F; Cyrillic Supplement
+ // 0530..058F; Armenian
+ // 0590..05FF; Hebrew
+ // 0600..06FF; Arabic
+ // 0700..074F; Syriac
+ // 0750..077F; Arabic Supplement
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0;
+ config->end = 32;
+ config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 33;
+ config->end = 0x77F + 1;
+ config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+
+ // CJK
+ // 2E80..2EFF; CJK Radicals Supplement
+ // 3000..303F; CJK Symbols and Punctuation
+ // 3040..309F; Hiragana
+ // 30A0..30FF; Katakana
+ // 3100..312F; Bopomofo
+ // 3130..318F; Hangul Compatibility Jamo
+ // 3190..319F; Kanbun
+ // 31A0..31BF; Bopomofo Extended
+ // 31C0..31EF; CJK Strokes
+ // 31F0..31FF; Katakana Phonetic Extensions
+ // 3200..32FF; Enclosed CJK Letters and Months
+ // 3300..33FF; CJK Compatibility
+ // 3400..4DBF; CJK Unified Ideographs Extension A
+ // 4DC0..4DFF; Yijing Hexagram Symbols
+ // 4E00..9FFF; CJK Unified Ideographs
+ // A000..A48F; Yi Syllables
+ // A490..A4CF; Yi Radicals
+ // A4D0..A4FF; Lisu
+ // A500..A63F; Vai
+ // F900..FAFF; CJK Compatibility Ideographs
+ // FE30..FE4F; CJK Compatibility Forms
+ // 20000..2A6DF; CJK Unified Ideographs Extension B
+ // 2A700..2B73F; CJK Unified Ideographs Extension C
+ // 2B740..2B81F; CJK Unified Ideographs Extension D
+ // 2B820..2CEAF; CJK Unified Ideographs Extension E
+ // 2CEB0..2EBEF; CJK Unified Ideographs Extension F
+ // 2F800..2FA1F; CJK Compatibility Ideographs Supplement
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2E80;
+ config->end = 0x2EFF + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x3000;
+ config->end = 0xA63F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0xF900;
+ config->end = 0xFAFF + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0xFE30;
+ config->end = 0xFE4F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x20000;
+ config->end = 0x2A6DF + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2A700;
+ config->end = 0x2B73F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2B740;
+ config->end = 0x2B81F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2B820;
+ config->end = 0x2CEAF + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2CEB0;
+ config->end = 0x2EBEF + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2F800;
+ config->end = 0x2FA1F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+
+ // Thai.
+ // 0E00..0E7F; Thai
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x0E00;
+ config->end = 0x0E7F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+
+ TestingTokenizerProxy tokenizer(configs, /*split_on_script_change=*/false);
+ std::vector<Token> tokens;
+
+ tokens = tokenizer.Tokenize(
+ "問少目木輸走猶術権自京門録球変。細開括省用掲情結傍走愛明氷。");
+ EXPECT_EQ(tokens.size(), 30);
+
+ tokens = tokenizer.Tokenize("問少目 hello 木輸ยามきゃ");
+ // clang-format off
+ EXPECT_THAT(
+ tokens,
+ ElementsAreArray({Token("問", 0, 1),
+ Token("少", 1, 2),
+ Token("目", 2, 3),
+ Token("hello", 4, 9),
+ Token("木", 10, 11),
+ Token("輸", 11, 12),
+ Token("ย", 12, 13),
+ Token("า", 13, 14),
+ Token("ม", 14, 15),
+ Token("き", 15, 16),
+ Token("ゃ", 16, 17)}));
+ // clang-format on
+}
+
+} // namespace
+} // namespace libtextclassifier2
diff --git a/types.h b/types.h
new file mode 100644
index 0000000..d50d438
--- /dev/null
+++ b/types.h
@@ -0,0 +1,164 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TYPES_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TYPES_H_
+
+#include <algorithm>
+#include <functional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "util/base/logging.h"
+
+namespace libtextclassifier2 {
+
+constexpr int kInvalidIndex = -1;
+
+// Index for a 0-based array of tokens.
+using TokenIndex = int;
+
+// Index for a 0-based array of codepoints.
+using CodepointIndex = int;
+
+// Marks a span in a sequence of codepoints. The first element is the index of
+// the first codepoint of the span, and the second element is the index of the
+// codepoint one past the end of the span.
+// TODO(b/71982294): Make it a struct.
+using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;
+
+// Marks a span in a sequence of tokens. The first element is the index of the
+// first token in the span, and the second element is the index of the token one
+// past the end of the span.
+// TODO(b/71982294): Make it a struct.
+using TokenSpan = std::pair<TokenIndex, TokenIndex>;
+
+// Returns the size of the token span. Assumes that the span is valid.
+inline int TokenSpanSize(const TokenSpan& token_span) {
+ return token_span.second - token_span.first;
+}
+
+// Returns a token span consisting of one token.
+inline TokenSpan SingleTokenSpan(int token_index) {
+ return {token_index, token_index + 1};
+}
+
+// Returns an intersection of two token spans. Assumes that both spans are valid
+// and overlapping.
+inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1,
+ const TokenSpan& token_span2) {
+ return {std::max(token_span1.first, token_span2.first),
+ std::min(token_span1.second, token_span2.second)};
+}
+
+// Returns an expanded token span by adding a certain number of tokens on its
+// left and on its right.
+inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span,
+ int num_tokens_left, int num_tokens_right) {
+ return {token_span.first - num_tokens_left,
+ token_span.second + num_tokens_right};
+}
+
+// Token holds a token, its position in the original string and whether it was
+// part of the input span.
+struct Token {
+ std::string value;
+ CodepointIndex start;
+ CodepointIndex end;
+
+ // Whether the token is a padding token.
+ bool is_padding;
+
+ // Default constructor constructs the padding-token.
+ Token()
+ : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {}
+
+ Token(const std::string& arg_value, CodepointIndex arg_start,
+ CodepointIndex arg_end)
+ : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {}
+
+ bool operator==(const Token& other) const {
+ return value == other.value && start == other.start && end == other.end &&
+ is_padding == other.is_padding;
+ }
+
+ bool IsContainedInSpan(CodepointSpan span) const {
+ return start >= span.first && end <= span.second;
+ }
+};
+
+// Pretty-printing function for Token.
+inline logging::LoggingStringStream& operator<<(
+ logging::LoggingStringStream& stream, const Token& token) {
+ if (!token.is_padding) {
+ return stream << "Token(\"" << token.value << "\", " << token.start << ", "
+ << token.end << ")";
+ } else {
+ return stream << "Token()";
+ }
+}
+
+// Represents a result of Annotate call.
+struct AnnotatedSpan {
+ // Unicode codepoint indices in the input string.
+ CodepointSpan span = {kInvalidIndex, kInvalidIndex};
+
+ // Classification result for the span.
+ std::vector<std::pair<std::string, float>> classification;
+};
+
+// Pretty-printing function for AnnotatedSpan.
+inline logging::LoggingStringStream& operator<<(
+ logging::LoggingStringStream& stream, const AnnotatedSpan& span) {
+ std::string best_class;
+ float best_score = -1;
+ if (!span.classification.empty()) {
+ best_class = span.classification[0].first;
+ best_score = span.classification[0].second;
+ }
+ return stream << "Span(" << span.span.first << ", " << span.span.second
+ << ", " << best_class << ", " << best_score << ")";
+}
+
+// StringPiece analogue for std::vector<T>.
+template <class T>
+class VectorSpan {
+ public:
+ VectorSpan() : begin_(), end_() {}
+ VectorSpan(const std::vector<T>& v) // NOLINT(runtime/explicit)
+ : begin_(v.begin()), end_(v.end()) {}
+ VectorSpan(typename std::vector<T>::const_iterator begin,
+ typename std::vector<T>::const_iterator end)
+ : begin_(begin), end_(end) {}
+
+ const T& operator[](typename std::vector<T>::size_type i) const {
+ return *(begin_ + i);
+ }
+
+ int size() const { return end_ - begin_; }
+ typename std::vector<T>::const_iterator begin() const { return begin_; }
+ typename std::vector<T>::const_iterator end() const { return end_; }
+  const T* data() const { return &(*begin_); }
+
+ private:
+ typename std::vector<T>::const_iterator begin_;
+ typename std::vector<T>::const_iterator end_;
+};
+
+} // namespace libtextclassifier2
+
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_TYPES_H_
diff --git a/util/base/casts.h b/util/base/casts.h
index 805ee89..c33173a 100644
--- a/util/base/casts.h
+++ b/util/base/casts.h
@@ -14,12 +14,12 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_CASTS_H_
-#define LIBTEXTCLASSIFIER_UTIL_BASE_CASTS_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_CASTS_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_CASTS_H_
#include <string.h> // for memcpy
-namespace libtextclassifier {
+namespace libtextclassifier2 {
// bit_cast<Dest, Source> is a template function that implements the equivalent
// of "*reinterpret_cast<Dest*>(&source)". We need this in very low-level
@@ -87,6 +87,6 @@
return dest;
}
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_UTIL_BASE_CASTS_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_CASTS_H_
diff --git a/util/base/config.h b/util/base/config.h
index e6c19a4..41b99a9 100644
--- a/util/base/config.h
+++ b/util/base/config.h
@@ -16,10 +16,10 @@
// Define macros to indicate C++ standard / platform / etc we use.
-#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_CONFIG_H_
-#define LIBTEXTCLASSIFIER_UTIL_BASE_CONFIG_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_CONFIG_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_CONFIG_H_
-namespace libtextclassifier {
+namespace libtextclassifier2 {
// Define LANG_CXX11 to 1 if current compiler supports C++11.
//
@@ -38,6 +38,6 @@
#define LANG_CXX11 1
#endif
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_UTIL_BASE_CONFIG_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_CONFIG_H_
diff --git a/util/base/endian.h b/util/base/endian.h
index 75f8bf7..2a6e654 100644
--- a/util/base/endian.h
+++ b/util/base/endian.h
@@ -14,12 +14,12 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_ENDIAN_H_
-#define LIBTEXTCLASSIFIER_UTIL_BASE_ENDIAN_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_ENDIAN_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_ENDIAN_H_
#include "util/base/integral_types.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
#if defined OS_LINUX || defined OS_CYGWIN || defined OS_ANDROID || \
defined(__ANDROID__)
@@ -40,7 +40,7 @@
// The following guarantees declaration of the byte swap functions, and
// defines __BYTE_ORDER for MSVC
-#if defined(__GLIBC__) || defined(__BIONIC__) || defined(__CYGWIN__)
+#if defined(__GLIBC__) || defined(__CYGWIN__)
#include <byteswap.h> // IWYU pragma: export
// The following section defines the byte swap functions for OS X / iOS,
// which does not ship with byteswap.h.
@@ -133,6 +133,6 @@
#endif /* ENDIAN */
};
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_UTIL_BASE_ENDIAN_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_ENDIAN_H_
diff --git a/util/base/integral_types.h b/util/base/integral_types.h
index 0322d33..a599f3c 100644
--- a/util/base/integral_types.h
+++ b/util/base/integral_types.h
@@ -16,12 +16,12 @@
// Basic integer type definitions.
-#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_INTEGRAL_TYPES_H_
-#define LIBTEXTCLASSIFIER_UTIL_BASE_INTEGRAL_TYPES_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_INTEGRAL_TYPES_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_INTEGRAL_TYPES_H_
#include "util/base/config.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
typedef unsigned int uint32;
typedef unsigned long long uint64;
@@ -56,6 +56,6 @@
static_assert(sizeof(int64) == 8, "wrong size");
#endif // LANG_CXX11
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_UTIL_BASE_INTEGRAL_TYPES_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_INTEGRAL_TYPES_H_
diff --git a/util/base/logging.cc b/util/base/logging.cc
index 9de35ca..919bb36 100644
--- a/util/base/logging.cc
+++ b/util/base/logging.cc
@@ -22,7 +22,7 @@
#include "util/base/logging_raw.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
namespace logging {
namespace {
@@ -57,12 +57,11 @@
}
LogMessage::~LogMessage() {
- const std::string message = stream_.str();
- LowLevelLogging(severity_, /* tag = */ "txtClsf", message);
+ LowLevelLogging(severity_, /* tag = */ "txtClsf", stream_.message);
if (severity_ == FATAL) {
exit(1);
}
}
} // namespace logging
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
diff --git a/util/base/logging.h b/util/base/logging.h
index dba0ed4..cebbbf2 100644
--- a/util/base/logging.h
+++ b/util/base/logging.h
@@ -14,18 +14,17 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_H_
-#define LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_H_
#include <cassert>
-#include <sstream>
#include <string>
#include "util/base/logging_levels.h"
#include "util/base/port.h"
// TC_STRIP
-namespace libtextclassifier {
+namespace libtextclassifier2 {
// string class that can't be instantiated. Makes sure that the code does not
// compile when non std::string is used.
//
@@ -38,12 +37,49 @@
// Makes the class non-instantiable.
virtual ~string() = 0;
};
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
// TC_END_STRIP
-namespace libtextclassifier {
+namespace libtextclassifier2 {
namespace logging {
+// A tiny code footprint string stream for assembling log messages.
+struct LoggingStringStream {
+ LoggingStringStream() {}
+ LoggingStringStream &stream() { return *this; }
+ // Needed for invocation in TC_CHECK macro.
+ explicit operator bool() const { return true; }
+
+ std::string message;
+};
+
+template <typename T>
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ const T &entry) {
+ stream.message.append(std::to_string(entry));
+ return stream;
+}
+
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ const char *message) {
+ stream.message.append(message);
+ return stream;
+}
+
+#if defined(HAS_GLOBAL_STRING)
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ const ::string &message) {
+ stream.message.append(message);
+ return stream;
+}
+#endif
+
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ const std::string &message) {
+ stream.message.append(message);
+ return stream;
+}
+
// The class that does all the work behind our TC_LOG(severity) macros. Each
// TC_LOG(severity) << obj1 << obj2 << ...; logging statement creates a
// LogMessage temporary object containing a stringstream. Each operator<< adds
@@ -61,19 +97,34 @@
~LogMessage() TC_ATTRIBUTE_NOINLINE;
// Returns the stream associated with the logger object.
- std::stringstream &stream() { return stream_; }
+ LoggingStringStream &stream() { return stream_; }
private:
const LogSeverity severity_;
// Stream that "prints" all info into a string (not to a file). We construct
// here the entire logging message and next print it in one operation.
- std::stringstream stream_;
+ LoggingStringStream stream_;
};
-#define TC_LOG(severity) \
- ::libtextclassifier::logging::LogMessage( \
- ::libtextclassifier::logging::severity, __FILE__, __LINE__) \
+// Pseudo-stream that "eats" the tokens <<-pumped into it, without printing
+// anything.
+class NullStream {
+ public:
+ NullStream() {}
+ NullStream &stream() { return *this; }
+};
+template <typename T>
+inline NullStream &operator<<(NullStream &str, const T &) {
+ return str;
+}
+
+} // namespace logging
+} // namespace libtextclassifier2
+
+#define TC_LOG(severity) \
+ ::libtextclassifier2::logging::LogMessage( \
+ ::libtextclassifier2::logging::severity, __FILE__, __LINE__) \
.stream()
// If condition x is true, does nothing. Otherwise, crashes the program (liek
@@ -92,19 +143,7 @@
#define TC_CHECK_GE(x, y) TC_CHECK((x) >= (y))
#define TC_CHECK_NE(x, y) TC_CHECK((x) != (y))
-// Pseudo-stream that "eats" the tokens <<-pumped into it, without printing
-// anything.
-class NullStream {
- public:
- NullStream() {}
- NullStream &stream() { return *this; }
-};
-template <typename T>
-inline NullStream &operator<<(NullStream &str, const T &) {
- return str;
-}
-
-#define TC_NULLSTREAM ::libtextclassifier::logging::NullStream().stream()
+#define TC_NULLSTREAM ::libtextclassifier2::logging::NullStream().stream()
// Debug checks: a TC_DCHECK<suffix> macro should behave like TC_CHECK<suffix>
// in debug mode an don't check / don't print anything in non-debug mode.
@@ -133,15 +172,12 @@
#endif // NDEBUG
#ifdef LIBTEXTCLASSIFIER_VLOG
-#define TC_VLOG(severity) \
- ::libtextclassifier::logging::LogMessage(::libtextclassifier::logging::INFO, \
- __FILE__, __LINE__) \
+#define TC_VLOG(severity) \
+ ::libtextclassifier2::logging::LogMessage( \
+ ::libtextclassifier2::logging::INFO, __FILE__, __LINE__) \
.stream()
#else
#define TC_VLOG(severity) TC_NULLSTREAM
#endif
-} // namespace logging
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_H_
diff --git a/util/base/logging_levels.h b/util/base/logging_levels.h
index d16f96a..7d7dff2 100644
--- a/util/base/logging_levels.h
+++ b/util/base/logging_levels.h
@@ -14,10 +14,10 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_LEVELS_H_
-#define LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_LEVELS_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_LEVELS_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_LEVELS_H_
-namespace libtextclassifier {
+namespace libtextclassifier2 {
namespace logging {
enum LogSeverity {
@@ -28,6 +28,6 @@
};
} // namespace logging
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_LEVELS_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_LEVELS_H_
diff --git a/util/base/logging_raw.cc b/util/base/logging_raw.cc
index 8e0eb1b..6d97852 100644
--- a/util/base/logging_raw.cc
+++ b/util/base/logging_raw.cc
@@ -26,7 +26,7 @@
// Compiled as part of Android.
#include <android/log.h>
-namespace libtextclassifier {
+namespace libtextclassifier2 {
namespace logging {
namespace {
@@ -60,12 +60,12 @@
}
} // namespace logging
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
#else // if defined(__ANDROID__)
// Not on Android: implement LowLevelLogging to print to stderr (see below).
-namespace libtextclassifier {
+namespace libtextclassifier2 {
namespace logging {
namespace {
@@ -94,6 +94,6 @@
}
} // namespace logging
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
#endif // if defined(__ANDROID__)
diff --git a/util/base/logging_raw.h b/util/base/logging_raw.h
index 40c2497..6cae105 100644
--- a/util/base/logging_raw.h
+++ b/util/base/logging_raw.h
@@ -14,14 +14,14 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_RAW_H_
-#define LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_RAW_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_RAW_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_RAW_H_
#include <string>
#include "util/base/logging_levels.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
namespace logging {
// Low-level logging primitive. Logs a message, with the indicated log
@@ -31,6 +31,6 @@
const std::string &message);
} // namespace logging
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_UTIL_BASE_LOGGING_RAW_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_LOGGING_RAW_H_
diff --git a/util/base/macros.h b/util/base/macros.h
index aec3a8a..7aca681 100644
--- a/util/base/macros.h
+++ b/util/base/macros.h
@@ -14,12 +14,12 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_MACROS_H_
-#define LIBTEXTCLASSIFIER_UTIL_BASE_MACROS_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_MACROS_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_MACROS_H_
#include "util/base/config.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
#if LANG_CXX11
#define TC_DISALLOW_COPY_AND_ASSIGN(TypeName) \
@@ -78,6 +78,6 @@
#define TC_FALLTHROUGH_INTENDED do { } while (0)
#endif
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_UTIL_BASE_MACROS_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_MACROS_H_
diff --git a/util/base/port.h b/util/base/port.h
index 394aaab..5a68daa 100644
--- a/util/base/port.h
+++ b/util/base/port.h
@@ -16,10 +16,10 @@
// Various portability macros, type definitions, and inline functions.
-#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_PORT_H_
-#define LIBTEXTCLASSIFIER_UTIL_BASE_PORT_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_PORT_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_PORT_H_
-namespace libtextclassifier {
+namespace libtextclassifier2 {
#if defined(__GNUC__) && \
(__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 1))
@@ -40,6 +40,6 @@
#define TC_ATTRIBUTE_NOINLINE
#endif
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_UTIL_BASE_PORT_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_BASE_PORT_H_
diff --git a/util/gtl/map_util.h b/util/gtl/map_util.h
index b5eaafa..d14071e 100644
--- a/util/gtl/map_util.h
+++ b/util/gtl/map_util.h
@@ -14,10 +14,10 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTIL_GTL_MAP_UTIL_H_
-#define LIBTEXTCLASSIFIER_UTIL_GTL_MAP_UTIL_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_GTL_MAP_UTIL_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_GTL_MAP_UTIL_H_
-namespace libtextclassifier {
+namespace libtextclassifier2 {
// Returns a const reference to the value associated with the given key if it
// exists, otherwise returns a const reference to the provided default value.
@@ -60,6 +60,6 @@
typename Collection::value_type(key, value));
}
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_UTIL_GTL_MAP_UTIL_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_GTL_MAP_UTIL_H_
diff --git a/util/gtl/stl_util.h b/util/gtl/stl_util.h
index 8e1c452..9d93c03 100644
--- a/util/gtl/stl_util.h
+++ b/util/gtl/stl_util.h
@@ -14,10 +14,10 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTIL_GTL_STL_UTIL_H_
-#define LIBTEXTCLASSIFIER_UTIL_GTL_STL_UTIL_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_GTL_STL_UTIL_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_GTL_STL_UTIL_H_
-namespace libtextclassifier {
+namespace libtextclassifier2 {
// Deletes all the elements in an STL container and clears the container. This
// function is suitable for use with a vector, set, hash_set, or any other STL
@@ -50,6 +50,6 @@
container->clear();
}
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_UTIL_GTL_STL_UTIL_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_GTL_STL_UTIL_H_
diff --git a/util/hash/farmhash.h b/util/hash/farmhash.h
index 7adf3aa..3bbe294 100644
--- a/util/hash/farmhash.h
+++ b/util/hash/farmhash.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTIL_HASH_FARMHASH_H_
-#define LIBTEXTCLASSIFIER_UTIL_HASH_FARMHASH_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_HASH_FARMHASH_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_HASH_FARMHASH_H_
#include <assert.h>
#include <stdint.h>
@@ -24,7 +24,7 @@
#include <utility>
#ifndef NAMESPACE_FOR_HASH_FUNCTIONS
-#define NAMESPACE_FOR_HASH_FUNCTIONS tcfarmhash
+#define NAMESPACE_FOR_HASH_FUNCTIONS tc2farmhash
#endif
namespace NAMESPACE_FOR_HASH_FUNCTIONS {
@@ -261,4 +261,4 @@
} // namespace NAMESPACE_FOR_HASH_FUNCTIONS
-#endif // LIBTEXTCLASSIFIER_UTIL_HASH_FARMHASH_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_HASH_FARMHASH_H_
diff --git a/util/hash/hash.cc b/util/hash/hash.cc
index 1261417..9722ddc 100644
--- a/util/hash/hash.cc
+++ b/util/hash/hash.cc
@@ -18,7 +18,7 @@
#include "util/base/macros.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
namespace {
// Lower-level versions of Get... that read directly from a character buffer
@@ -76,4 +76,4 @@
return h;
}
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
diff --git a/util/hash/hash.h b/util/hash/hash.h
index 0abb72b..beabd6e 100644
--- a/util/hash/hash.h
+++ b/util/hash/hash.h
@@ -14,14 +14,14 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTIL_HASH_HASH_H_
-#define LIBTEXTCLASSIFIER_UTIL_HASH_HASH_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_HASH_HASH_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_HASH_HASH_H_
#include <string>
#include "util/base/integral_types.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
uint32 Hash32(const char *data, size_t n, uint32 seed);
@@ -33,6 +33,6 @@
return Hash32WithDefaultSeed(input.data(), input.size());
}
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_UTIL_HASH_HASH_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_HASH_HASH_H_
diff --git a/util/java/scoped_local_ref.h b/util/java/scoped_local_ref.h
index d995468..e716df5 100644
--- a/util/java/scoped_local_ref.h
+++ b/util/java/scoped_local_ref.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_LOCAL_REF_H_
-#define LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_LOCAL_REF_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_JAVA_SCOPED_LOCAL_REF_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_JAVA_SCOPED_LOCAL_REF_H_
#include <jni.h>
#include <memory>
@@ -23,7 +23,7 @@
#include "util/base/logging.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
// A deleter to be used with std::unique_ptr to delete JNI local references.
class LocalRefDeleter {
@@ -60,6 +60,6 @@
using ScopedLocalRef =
std::unique_ptr<typename std::remove_pointer<T>::type, LocalRefDeleter>;
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_LOCAL_REF_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_JAVA_SCOPED_LOCAL_REF_H_
diff --git a/common/fastexp.cc b/util/math/fastexp.cc
similarity index 93%
rename from common/fastexp.cc
rename to util/math/fastexp.cc
index 0376ad2..4bf8592 100644
--- a/common/fastexp.cc
+++ b/util/math/fastexp.cc
@@ -14,10 +14,9 @@
* limitations under the License.
*/
-#include "common/fastexp.h"
+#include "util/math/fastexp.h"
-namespace libtextclassifier {
-namespace nlp_core {
+namespace libtextclassifier2 {
const int FastMathClass::kBits;
const int FastMathClass::kMask1;
@@ -46,5 +45,4 @@
7940441, 8029106, 8118253, 8207884, 8298001}
};
-} // namespace nlp_core
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
diff --git a/common/fastexp.h b/util/math/fastexp.h
similarity index 85%
rename from common/fastexp.h
rename to util/math/fastexp.h
index 1781b36..acc1453 100644
--- a/common/fastexp.h
+++ b/util/math/fastexp.h
@@ -16,8 +16,8 @@
// Fast approximation for exp.
-#ifndef LIBTEXTCLASSIFIER_COMMON_FASTEXP_H_
-#define LIBTEXTCLASSIFIER_COMMON_FASTEXP_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MATH_FASTEXP_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MATH_FASTEXP_H_
#include <cassert>
#include <cmath>
@@ -27,8 +27,7 @@
#include "util/base/integral_types.h"
#include "util/base/logging.h"
-namespace libtextclassifier {
-namespace nlp_core {
+namespace libtextclassifier2 {
class FastMathClass {
private:
@@ -64,7 +63,6 @@
inline float VeryFastExp2(float f) { return FastMathInstance.VeryFastExp2(f); }
inline float VeryFastExp(float f) { return FastMathInstance.VeryFastExp(f); }
-} // namespace nlp_core
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_COMMON_FASTEXP_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MATH_FASTEXP_H_
diff --git a/common/softmax.cc b/util/math/softmax.cc
similarity index 85%
rename from common/softmax.cc
rename to util/math/softmax.cc
index 3610de8..986787f 100644
--- a/common/softmax.cc
+++ b/util/math/softmax.cc
@@ -14,15 +14,14 @@
* limitations under the License.
*/
-#include "common/softmax.h"
+#include "util/math/softmax.h"
#include <limits>
-#include "common/fastexp.h"
#include "util/base/logging.h"
+#include "util/math/fastexp.h"
-namespace libtextclassifier {
-namespace nlp_core {
+namespace libtextclassifier2 {
float ComputeSoftmaxProbability(const std::vector<float> &scores, int label) {
if ((label < 0) || (label >= scores.size())) {
@@ -71,18 +70,24 @@
}
std::vector<float> ComputeSoftmax(const std::vector<float> &scores) {
+ return ComputeSoftmax(scores.data(), scores.size());
+}
+
+std::vector<float> ComputeSoftmax(const float *scores, int scores_size) {
std::vector<float> softmax;
std::vector<float> exp_scores;
- exp_scores.reserve(scores.size());
- softmax.reserve(scores.size());
+ exp_scores.reserve(scores_size);
+ softmax.reserve(scores_size);
// Find max value in "scores" vector and rescale to avoid overflows.
float max = std::numeric_limits<float>::min();
- for (const auto &score : scores) {
+ for (int i = 0; i < scores_size; ++i) {
+ const float score = scores[i];
if (score > max) max = score;
}
float denominator = 0;
- for (auto &score : scores) {
+ for (int i = 0; i < scores_size; ++i) {
+ const float score = scores[i];
// See comments above in ComputeSoftmaxProbability for the reasoning behind
// this approximation.
const float exp_score = score - max < -16.0f ? 0 : VeryFastExp(score - max);
@@ -90,11 +95,10 @@
denominator += exp_score;
}
- for (int i = 0; i < scores.size(); ++i) {
+ for (int i = 0; i < scores_size; ++i) {
softmax.push_back(exp_scores[i] / denominator);
}
return softmax;
}
-} // namespace nlp_core
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
diff --git a/common/softmax.h b/util/math/softmax.h
similarity index 71%
rename from common/softmax.h
rename to util/math/softmax.h
index e1cc2d9..57bf832 100644
--- a/common/softmax.h
+++ b/util/math/softmax.h
@@ -14,13 +14,12 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_COMMON_SOFTMAX_H_
-#define LIBTEXTCLASSIFIER_COMMON_SOFTMAX_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MATH_SOFTMAX_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MATH_SOFTMAX_H_
#include <vector>
-namespace libtextclassifier {
-namespace nlp_core {
+namespace libtextclassifier2 {
// Computes probability of a softmax label. Parameter "scores" is the vector of
// softmax logits. Returns 0.0f if "label" is outside the range [0,
@@ -31,7 +30,9 @@
// "scores" is the vector of softmax logits.
std::vector<float> ComputeSoftmax(const std::vector<float> &scores);
-} // namespace nlp_core
-} // namespace libtextclassifier
+// Same as above but operates on an array of floats.
+std::vector<float> ComputeSoftmax(const float *scores, int scores_size);
-#endif // LIBTEXTCLASSIFIER_COMMON_SOFTMAX_H_
+} // namespace libtextclassifier2
+
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MATH_SOFTMAX_H_
diff --git a/common/mmap.cc b/util/memory/mmap.cc
similarity index 96%
rename from common/mmap.cc
rename to util/memory/mmap.cc
index 6e15a84..6b0bdf2 100644
--- a/common/mmap.cc
+++ b/util/memory/mmap.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "common/mmap.h"
+#include "util/memory/mmap.h"
#include <errno.h>
#include <fcntl.h>
@@ -27,8 +27,7 @@
#include "util/base/logging.h"
#include "util/base/macros.h"
-namespace libtextclassifier {
-namespace nlp_core {
+namespace libtextclassifier2 {
namespace {
inline std::string GetLastSystemError() { return std::string(strerror(errno)); }
@@ -133,5 +132,4 @@
return true;
}
-} // namespace nlp_core
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
diff --git a/common/mmap.h b/util/memory/mmap.h
similarity index 93%
rename from common/mmap.h
rename to util/memory/mmap.h
index 69f7b4c..781f222 100644
--- a/common/mmap.h
+++ b/util/memory/mmap.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_COMMON_MMAP_H_
-#define LIBTEXTCLASSIFIER_COMMON_MMAP_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MEMORY_MMAP_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MEMORY_MMAP_H_
#include <stddef.h>
@@ -24,8 +24,7 @@
#include "util/base/integral_types.h"
#include "util/strings/stringpiece.h"
-namespace libtextclassifier {
-namespace nlp_core {
+namespace libtextclassifier2 {
// Handle for a memory area where a file has been mmapped.
//
@@ -137,7 +136,6 @@
MmapHandle handle_;
};
-} // namespace nlp_core
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_COMMON_MMAP_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_MEMORY_MMAP_H_
diff --git a/util/strings/numbers.cc b/util/strings/numbers.cc
index 4bd8b82..a89c0ef 100644
--- a/util/strings/numbers.cc
+++ b/util/strings/numbers.cc
@@ -22,7 +22,7 @@
#include <stdlib.h>
-namespace libtextclassifier {
+namespace libtextclassifier2 {
bool ParseInt32(const char *c_str, int32 *value) {
char *temp;
@@ -72,4 +72,4 @@
}
#endif // COMPILER_MSVC
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
diff --git a/util/strings/numbers.h b/util/strings/numbers.h
index eda53bf..096954e 100644
--- a/util/strings/numbers.h
+++ b/util/strings/numbers.h
@@ -14,14 +14,14 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTIL_STRINGS_NUMBERS_H_
-#define LIBTEXTCLASSIFIER_UTIL_STRINGS_NUMBERS_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_NUMBERS_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_NUMBERS_H_
#include <string>
#include "util/base/integral_types.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
// Parses an int32 from a C-style string.
//
@@ -47,7 +47,6 @@
// int types.
std::string IntToString(int64 input);
+} // namespace libtextclassifier2
-} // namespace libtextclassifier
-
-#endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_NUMBERS_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_NUMBERS_H_
diff --git a/util/strings/numbers_test.cc b/util/strings/numbers_test.cc
index f3a3f27..1fdd78a 100644
--- a/util/strings/numbers_test.cc
+++ b/util/strings/numbers_test.cc
@@ -19,7 +19,7 @@
#include "util/base/integral_types.h"
#include "gtest/gtest.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
namespace {
void TestParseInt32(const char *c_str, bool expected_parsing_success,
@@ -100,4 +100,4 @@
TestParseDouble("23.5a", false);
}
} // namespace
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
diff --git a/util/strings/split.cc b/util/strings/split.cc
index 8d250bb..e61e3ba 100644
--- a/util/strings/split.cc
+++ b/util/strings/split.cc
@@ -16,7 +16,7 @@
#include "util/strings/split.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
namespace strings {
std::vector<std::string> Split(const std::string &text, char delim) {
@@ -35,4 +35,4 @@
}
} // namespace strings
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
diff --git a/util/strings/split.h b/util/strings/split.h
index b661ede0..9860265 100644
--- a/util/strings/split.h
+++ b/util/strings/split.h
@@ -14,18 +14,18 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTIL_STRINGS_SPLIT_H_
-#define LIBTEXTCLASSIFIER_UTIL_STRINGS_SPLIT_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_SPLIT_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_SPLIT_H_
#include <string>
#include <vector>
-namespace libtextclassifier {
+namespace libtextclassifier2 {
namespace strings {
std::vector<std::string> Split(const std::string &text, char delim);
} // namespace strings
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_SPLIT_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_SPLIT_H_
diff --git a/util/strings/stringpiece.h b/util/strings/stringpiece.h
index 8c42d83..f6187e9 100644
--- a/util/strings/stringpiece.h
+++ b/util/strings/stringpiece.h
@@ -14,14 +14,14 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTIL_STRINGS_STRINGPIECE_H_
-#define LIBTEXTCLASSIFIER_UTIL_STRINGS_STRINGPIECE_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_STRINGPIECE_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_STRINGPIECE_H_
#include <stddef.h>
#include <string>
-namespace libtextclassifier {
+namespace libtextclassifier2 {
// Read-only "view" of a piece of data. Does not own the underlying data.
class StringPiece {
@@ -61,6 +61,6 @@
size_t size_;
};
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_STRINGPIECE_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_STRINGPIECE_H_
diff --git a/util/strings/utf8.h b/util/strings/utf8.h
index 93c7fea..89823e2 100644
--- a/util/strings/utf8.h
+++ b/util/strings/utf8.h
@@ -14,10 +14,10 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTIL_STRINGS_UTF8_H_
-#define LIBTEXTCLASSIFIER_UTIL_STRINGS_UTF8_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_UTF8_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_UTF8_H_
-namespace libtextclassifier {
+namespace libtextclassifier2 {
// Returns the length (number of bytes) of the Unicode code point starting at
// src, based on inspecting just that one byte. Preconditions: src != NULL,
@@ -44,6 +44,6 @@
return static_cast<signed char>(x) < -0x40;
}
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_UTF8_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_STRINGS_UTF8_H_
diff --git a/util/utf8/unicodetext.cc b/util/utf8/unicodetext.cc
index dbab1c8..c814a2e 100644
--- a/util/utf8/unicodetext.cc
+++ b/util/utf8/unicodetext.cc
@@ -22,7 +22,7 @@
#include "util/strings/utf8.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
// *************** Data representation **********
// Note: the copy constructor is undefined.
@@ -109,6 +109,61 @@
return *this;
}
+namespace {
+
+enum {
+ RuneError = 0xFFFD, // Decoding error in UTF.
+ RuneMax = 0x10FFFF, // Maximum rune value.
+};
+
+int runetochar(const char32 rune, char* dest) {
+ // Convert to unsigned for range check.
+ uint32 c;
+
+ // 1 char 00-7F
+ c = rune;
+ if (c <= 0x7F) {
+ dest[0] = static_cast<char>(c);
+ return 1;
+ }
+
+ // 2 char 0080-07FF
+ if (c <= 0x07FF) {
+ dest[0] = 0xC0 | static_cast<char>(c >> 1 * 6);
+ dest[1] = 0x80 | (c & 0x3F);
+ return 2;
+ }
+
+ // Range check
+ if (c > RuneMax) {
+ c = RuneError;
+ }
+
+ // 3 char 0800-FFFF
+ if (c <= 0xFFFF) {
+ dest[0] = 0xE0 | static_cast<char>(c >> 2 * 6);
+ dest[1] = 0x80 | ((c >> 1 * 6) & 0x3F);
+ dest[2] = 0x80 | (c & 0x3F);
+ return 3;
+ }
+
+ // 4 char 10000-1FFFFF
+ dest[0] = 0xF0 | static_cast<char>(c >> 3 * 6);
+ dest[1] = 0x80 | ((c >> 2 * 6) & 0x3F);
+ dest[2] = 0x80 | ((c >> 1 * 6) & 0x3F);
+ dest[3] = 0x80 | (c & 0x3F);
+ return 4;
+}
+
+} // namespace
+
+UnicodeText& UnicodeText::AppendCodepoint(char32 ch) {
+ char str[4];
+ int char_len = runetochar(ch, str);
+ repr_.append(str, char_len);
+ return *this;
+}
+
void UnicodeText::clear() { repr_.clear(); }
int UnicodeText::size() const { return std::distance(begin(), end()); }
@@ -195,4 +250,4 @@
return UTF8ToUnicodeText(str.data(), str.size(), do_copy);
}
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
diff --git a/util/utf8/unicodetext.h b/util/utf8/unicodetext.h
index 6a21058..d331f9b 100644
--- a/util/utf8/unicodetext.h
+++ b/util/utf8/unicodetext.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTIL_UTF8_UNICODETEXT_H_
-#define LIBTEXTCLASSIFIER_UTIL_UTF8_UNICODETEXT_H_
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNICODETEXT_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNICODETEXT_H_
#include <iterator>
#include <string>
@@ -23,7 +23,7 @@
#include "util/base/integral_types.h"
-namespace libtextclassifier {
+namespace libtextclassifier2 {
// ***************************** UnicodeText **************************
//
@@ -150,6 +150,7 @@
// Calling this may invalidate pointers to underlying data.
UnicodeText& AppendUTF8(const char* utf8, int len);
+ UnicodeText& AppendCodepoint(char32 ch);
void clear();
static std::string UTF8Substring(const const_iterator& first,
@@ -193,6 +194,6 @@
UnicodeText UTF8ToUnicodeText(const char* utf8_buf, int len, bool do_copy);
UnicodeText UTF8ToUnicodeText(const std::string& str, bool do_copy);
-} // namespace libtextclassifier
+} // namespace libtextclassifier2
-#endif // LIBTEXTCLASSIFIER_UTIL_UTF8_UNICODETEXT_H_
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNICODETEXT_H_
diff --git a/util/utf8/unilib-icu.cc b/util/utf8/unilib-icu.cc
new file mode 100644
index 0000000..147a364
--- /dev/null
+++ b/util/utf8/unilib-icu.cc
@@ -0,0 +1,109 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/utf8/unilib-icu.h"
+
+#include "util/base/logging.h"
+
+namespace libtextclassifier2 {
+
+bool UniLib::IsOpeningBracket(char32 codepoint) const {
+ return u_getIntPropertyValue(codepoint, UCHAR_BIDI_PAIRED_BRACKET_TYPE) ==
+ U_BPT_OPEN;
+}
+
+bool UniLib::IsClosingBracket(char32 codepoint) const {
+ return u_getIntPropertyValue(codepoint, UCHAR_BIDI_PAIRED_BRACKET_TYPE) ==
+ U_BPT_CLOSE;
+}
+
+bool UniLib::IsWhitespace(char32 codepoint) const {
+ return u_isWhitespace(codepoint);
+}
+
+bool UniLib::IsDigit(char32 codepoint) const { return u_isdigit(codepoint); }
+
+bool UniLib::IsUpper(char32 codepoint) const { return u_isupper(codepoint); }
+
+char32 UniLib::ToLower(char32 codepoint) const { return u_tolower(codepoint); }
+
+char32 UniLib::GetPairedBracket(char32 codepoint) const {
+ return u_getBidiPairedBracket(codepoint);
+}
+
+bool UniLib::RegexPattern::Matches(const std::string& text) {
+ const icu::UnicodeString unicode_text(text.c_str(), text.size(), "utf-8");
+ UErrorCode status;
+ status = U_ZERO_ERROR;
+ std::unique_ptr<icu::RegexMatcher> matcher(
+ pattern_->matcher(unicode_text, status));
+ if (U_FAILURE(status) || !matcher) {
+ return false;
+ }
+
+ status = U_ZERO_ERROR;
+ const bool result = matcher->matches(/*startIndex=*/0, status);
+ if (U_FAILURE(status)) {
+ return false;
+ }
+
+ return result;
+}
+
+constexpr int UniLib::BreakIterator::kDone;
+
+UniLib::BreakIterator::BreakIterator(const std::string& text) {
+ icu::ErrorCode status;
+ break_iterator_.reset(
+ icu::BreakIterator::createWordInstance(icu::Locale("en"), status));
+ if (!status.isSuccess()) {
+ break_iterator_.reset();
+ return;
+ }
+
+ const icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(text);
+ break_iterator_->setText(unicode_text);
+}
+
+int UniLib::BreakIterator::Next() {
+ const int result = break_iterator_->next();
+ if (result == icu::BreakIterator::DONE) {
+ return BreakIterator::kDone;
+ } else {
+ return result;
+ }
+}
+
+std::unique_ptr<UniLib::RegexPattern> UniLib::CreateRegexPattern(
+ const std::string& regex) const {
+ UErrorCode status = U_ZERO_ERROR;
+ std::unique_ptr<icu::RegexPattern> pattern(icu::RegexPattern::compile(
+ icu::UnicodeString(regex.c_str(), regex.size(), "utf-8"), /*flags=*/0,
+ status));
+ if (U_FAILURE(status) || !pattern) {
+ return nullptr;
+ }
+ return std::unique_ptr<UniLib::RegexPattern>(
+ new UniLib::RegexPattern(std::move(pattern)));
+}
+
+std::unique_ptr<UniLib::BreakIterator> UniLib::CreateBreakIterator(
+ const std::string& text) const {
+ return std::unique_ptr<UniLib::BreakIterator>(
+ new UniLib::BreakIterator(text));
+}
+
+} // namespace libtextclassifier2
diff --git a/util/utf8/unilib-icu.h b/util/utf8/unilib-icu.h
new file mode 100644
index 0000000..0d34b74
--- /dev/null
+++ b/util/utf8/unilib-icu.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// UniLib implementation with the help of ICU. UniLib is basically a wrapper
+// around the ICU functionality.
+
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNILIB_ICU_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNILIB_ICU_H_
+
+#include <memory>
+#include <string>
+
+#include "util/base/integral_types.h"
+#include "unicode/brkiter.h"
+#include "unicode/errorcode.h"
+#include "unicode/regex.h"
+#include "unicode/uchar.h"
+
+namespace libtextclassifier2 {
+
+class UniLib {
+ public:
+ bool IsOpeningBracket(char32 codepoint) const;
+ bool IsClosingBracket(char32 codepoint) const;
+ bool IsWhitespace(char32 codepoint) const;
+ bool IsDigit(char32 codepoint) const;
+ bool IsUpper(char32 codepoint) const;
+
+ char32 ToLower(char32 codepoint) const;
+ char32 GetPairedBracket(char32 codepoint) const;
+
+ class RegexPattern {
+ public:
+ // Returns true if the whole input matches with the regex.
+ bool Matches(const std::string& text);
+
+ protected:
+ friend class UniLib;
+ explicit RegexPattern(std::unique_ptr<icu::RegexPattern> pattern)
+ : pattern_(std::move(pattern)) {}
+
+ private:
+ std::unique_ptr<icu::RegexPattern> pattern_;
+ };
+
+ class BreakIterator {
+ public:
+ int Next();
+
+ static constexpr int kDone = -1;
+
+ protected:
+ friend class UniLib;
+ explicit BreakIterator(const std::string& text);
+
+ private:
+ std::unique_ptr<icu::BreakIterator> break_iterator_;
+ };
+
+ std::unique_ptr<RegexPattern> CreateRegexPattern(
+ const std::string& regex) const;
+ std::unique_ptr<BreakIterator> CreateBreakIterator(
+ const std::string& text) const;
+};
+
+} // namespace libtextclassifier2
+
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNILIB_ICU_H_
diff --git a/util/utf8/unilib.h b/util/utf8/unilib.h
new file mode 100644
index 0000000..b583d72
--- /dev/null
+++ b/util/utf8/unilib.h
@@ -0,0 +1,28 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNILIB_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNILIB_H_
+
+#if defined LIBTEXTCLASSIFIER_UNILIB_ICU
+#include "util/utf8/unilib-icu.h"
+#elif defined LIBTEXTCLASSIFIER_UNILIB_DUMMY
+#include "util/utf8/unilib-dummy.h"
+#else
+#error No LIBTEXTCLASSIFIER_UNILIB implementation specified.
+#endif
+
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_UTIL_UTF8_UNILIB_H_
diff --git a/util/utf8/unilib_test.cc b/util/utf8/unilib_test.cc
new file mode 100644
index 0000000..a1bbdf4
--- /dev/null
+++ b/util/utf8/unilib_test.cc
@@ -0,0 +1,45 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/utf8/unilib.h"
+
+#include "util/base/logging.h"
+
+#include "gtest/gtest.h"
+
+namespace libtextclassifier2 {
+namespace {
+
+TEST(UniLibTest, Interface) {
+ UniLib unilib;
+ TC_LOG(INFO) << unilib.IsOpeningBracket('(');
+ TC_LOG(INFO) << unilib.IsClosingBracket(')');
+ TC_LOG(INFO) << unilib.IsWhitespace(')');
+ TC_LOG(INFO) << unilib.IsDigit(')');
+ TC_LOG(INFO) << unilib.IsUpper(')');
+ TC_LOG(INFO) << unilib.ToLower(')');
+ TC_LOG(INFO) << unilib.GetPairedBracket(')');
+ std::unique_ptr<UniLib::RegexPattern> pattern =
+ unilib.CreateRegexPattern("[0-9]");
+ TC_LOG(INFO) << pattern->Matches("Hello");
+ std::unique_ptr<UniLib::BreakIterator> iterator =
+ unilib.CreateBreakIterator("some text");
+ TC_LOG(INFO) << iterator->Next();
+ TC_LOG(INFO) << UniLib::BreakIterator::kDone;
+}
+
+} // namespace
+} // namespace libtextclassifier2