Merge branch 'temp_72223856' of persistent-https://android.git.corp.google.com/platform/external/libtextclassifier into merge1 am: d16be73edc
am: 9ac27ad018
Change-Id: I8f5ca294024401ec4924751acb30221578188351
diff --git a/Android.bp b/Android.bp
index 0ca3b5f..ca3521a 100644
--- a/Android.bp
+++ b/Android.bp
@@ -1,12 +1,25 @@
+// Copyright (C) 2017 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
cc_library_headers {
name: "libtextclassifier_hash_headers",
vendor_available: true,
export_include_dirs: ["."],
}
-cc_library_shared {
- name: "libtextclassifier_hash",
- vendor_available: true,
+cc_defaults {
+ name: "libtextclassifier_hash_defaults",
srcs: [
"util/hash/farmhash.cc",
"util/hash/hash.cc"
@@ -18,3 +31,16 @@
"-Wno-unused-function",
],
}
+
+// Shared-library build of the hash sources; the source list and the other
+// common settings come from libtextclassifier_hash_defaults.
+cc_library_shared {
+ name: "libtextclassifier_hash",
+ defaults: ["libtextclassifier_hash_defaults"],
+ vendor_available: true,
+}
+
+// Static variant of the same sources (via libtextclassifier_hash_defaults),
+// built against sdk_version "current" with the static libc++ STL.
+cc_library_static {
+ name: "libtextclassifier_hash_static",
+ defaults: ["libtextclassifier_hash_defaults"],
+ sdk_version: "current",
+ stl: "libc++_static",
+}
diff --git a/Android.mk b/Android.mk
index b4ce617..2d19e3d 100644
--- a/Android.mk
+++ b/Android.mk
@@ -73,13 +73,14 @@
LOCAL_CFLAGS += $(MY_LIBTEXTCLASSIFIER_CFLAGS)
LOCAL_STRIP_MODULE := $(LIBTEXTCLASSIFIER_STRIP_OPTS)
-LOCAL_SRC_FILES := $(filter-out tests/%,$(call all-subdir-cpp-files))
+LOCAL_SRC_FILES := $(filter-out tests/% %_test.cc,$(call all-subdir-cpp-files))
LOCAL_C_INCLUDES += $(proto_sources_dir)/proto/external/libtextclassifier
LOCAL_STATIC_LIBRARIES += libtextclassifier_protos
LOCAL_SHARED_LIBRARIES += libprotobuf-cpp-lite
LOCAL_SHARED_LIBRARIES += liblog
LOCAL_SHARED_LIBRARIES += libicuuc libicui18n
+LOCAL_REQUIRED_MODULES := textclassifier.smartselection.en.model
LOCAL_ADDITIONAL_DEPENDENCIES += $(LOCAL_PATH)/jni.lds
LOCAL_LDFLAGS += -Wl,-version-script=$(LOCAL_PATH)/jni.lds
diff --git a/tests/embedding-feature-extractor_test.cc b/common/embedding-feature-extractor_test.cc
similarity index 100%
rename from tests/embedding-feature-extractor_test.cc
rename to common/embedding-feature-extractor_test.cc
diff --git a/tests/embedding-network_test.cc b/common/embedding-network_test.cc
similarity index 100%
rename from tests/embedding-network_test.cc
rename to common/embedding-network_test.cc
diff --git a/tests/fml-parser_test.cc b/common/fml-parser_test.cc
similarity index 100%
rename from tests/fml-parser_test.cc
rename to common/fml-parser_test.cc
diff --git a/common/little-endian-data.h b/common/little-endian-data.h
index 8441cd7..e3bc88f 100644
--- a/common/little-endian-data.h
+++ b/common/little-endian-data.h
@@ -21,7 +21,7 @@
#include <string>
#include <vector>
-#include "base.h"
+#include "util/base/endian.h"
#include "util/base/logging.h"
namespace libtextclassifier {
diff --git a/common/memory_image/data-store.cc b/common/memory_image/data-store.cc
index dcdaa82..a5f500c 100644
--- a/common/memory_image/data-store.cc
+++ b/common/memory_image/data-store.cc
@@ -29,12 +29,12 @@
}
}
-DataBlobView DataStore::GetData(const std::string &name) const {
+StringPiece DataStore::GetData(const std::string &name) const {
if (!reader_.success_status()) {
TC_LOG(ERROR) << "DataStore::GetData(" << name << ") "
<< "called on invalid DataStore; will return empty data "
<< "chunk";
- return DataBlobView();
+ return StringPiece();
}
const auto &entries = reader_.trimmed_proto().entries();
@@ -42,14 +42,14 @@
if (it == entries.end()) {
TC_LOG(ERROR) << "Unknown key: " << name
<< "; will return empty data chunk";
- return DataBlobView();
+ return StringPiece();
}
const DataStoreEntryBytes &entry_bytes = it->second;
if (!entry_bytes.has_blob_index()) {
TC_LOG(ERROR) << "DataStoreEntryBytes with no blob_index; "
<< "will return empty data chunk.";
- return DataBlobView();
+ return StringPiece();
}
int blob_index = entry_bytes.blob_index();
diff --git a/common/memory_image/data-store.h b/common/memory_image/data-store.h
index 9af5e0e..56aa4fc 100644
--- a/common/memory_image/data-store.h
+++ b/common/memory_image/data-store.h
@@ -20,7 +20,6 @@
#include <string>
#include "common/memory_image/data-store.pb.h"
-#include "common/memory_image/memory-image-common.h"
#include "common/memory_image/memory-image-reader.h"
#include "util/strings/stringpiece.h"
@@ -46,7 +45,7 @@
// If the alignment is a low power of 2 (e..g, 4, 8, or 16) and "start" passed
// to constructor corresponds to the beginning of a memory page or an address
// returned by new or malloc(), then start_addr is divisible with alignment.
- DataBlobView GetData(const std::string &name) const;
+ StringPiece GetData(const std::string &name) const;
private:
MemoryImageReader<DataStoreProto> reader_;
diff --git a/common/memory_image/embedding-network-params-from-image.h b/common/memory_image/embedding-network-params-from-image.h
index feb4817..e8c7d1e 100644
--- a/common/memory_image/embedding-network-params-from-image.h
+++ b/common/memory_image/embedding-network-params-from-image.h
@@ -22,6 +22,7 @@
#include "common/embedding-network.pb.h"
#include "common/memory_image/memory-image-reader.h"
#include "util/base/integral_types.h"
+#include "util/strings/stringpiece.h"
namespace libtextclassifier {
namespace nlp_core {
@@ -84,7 +85,7 @@
const int blob_index = trimmed_proto_.embeddings(i).is_quantized()
? (embeddings_blob_offset_ + 2 * i)
: (embeddings_blob_offset_ + i);
- DataBlobView data_blob_view = memory_reader_.data_blob_view(blob_index);
+ StringPiece data_blob_view = memory_reader_.data_blob_view(blob_index);
return data_blob_view.data();
}
@@ -104,7 +105,7 @@
// one blob with the quantized values and (immediately after it, hence the
// "+ 1") one blob with the scales.
int blob_index = embeddings_blob_offset_ + 2 * i + 1;
- DataBlobView data_blob_view = memory_reader_.data_blob_view(blob_index);
+ StringPiece data_blob_view = memory_reader_.data_blob_view(blob_index);
return reinterpret_cast<const float16 *>(data_blob_view.data());
} else {
return nullptr;
@@ -125,7 +126,7 @@
const void *hidden_weights(int i) const override {
TC_DCHECK(InRange(i, hidden_size()));
- DataBlobView data_blob_view =
+ StringPiece data_blob_view =
memory_reader_.data_blob_view(hidden_blob_offset_ + i);
return data_blob_view.data();
}
@@ -146,7 +147,7 @@
const void *hidden_bias_weights(int i) const override {
TC_DCHECK(InRange(i, hidden_bias_size()));
- DataBlobView data_blob_view =
+ StringPiece data_blob_view =
memory_reader_.data_blob_view(hidden_bias_blob_offset_ + i);
return data_blob_view.data();
}
@@ -167,7 +168,7 @@
const void *softmax_weights(int i) const override {
TC_DCHECK(InRange(i, softmax_size()));
- DataBlobView data_blob_view =
+ StringPiece data_blob_view =
memory_reader_.data_blob_view(softmax_blob_offset_ + i);
return data_blob_view.data();
}
@@ -188,7 +189,7 @@
const void *softmax_bias_weights(int i) const override {
TC_DCHECK(InRange(i, softmax_bias_size()));
- DataBlobView data_blob_view =
+ StringPiece data_blob_view =
memory_reader_.data_blob_view(softmax_bias_blob_offset_ + i);
return data_blob_view.data();
}
diff --git a/common/memory_image/in-memory-model-data.cc b/common/memory_image/in-memory-model-data.cc
index f5cda3a..acf3d86 100644
--- a/common/memory_image/in-memory-model-data.cc
+++ b/common/memory_image/in-memory-model-data.cc
@@ -17,7 +17,6 @@
#include "common/memory_image/in-memory-model-data.h"
#include "common/file-utils.h"
-#include "common/memory_image/memory-image-common.h"
#include "util/base/logging.h"
#include "util/strings/stringpiece.h"
@@ -28,14 +27,13 @@
const char InMemoryModelData::kFilePatternPrefix[] = "in-mem-model::";
bool InMemoryModelData::GetTaskSpec(TaskSpec *task_spec) const {
- DataBlobView blob = data_store_.GetData(kTaskSpecDataStoreEntryName);
+ StringPiece blob = data_store_.GetData(kTaskSpecDataStoreEntryName);
if (blob.data() == nullptr) {
TC_LOG(ERROR) << "Can't find data blob for TaskSpec, i.e., entry "
<< kTaskSpecDataStoreEntryName;
return false;
}
- bool parse_status = file_utils::ParseProtoFromMemory(
- blob.to_stringpiece(), task_spec);
+ bool parse_status = file_utils::ParseProtoFromMemory(blob, task_spec);
if (!parse_status) {
TC_LOG(ERROR) << "Error parsing TaskSpec";
return false;
@@ -43,13 +41,5 @@
return true;
}
-StringPiece InMemoryModelData::GetBytesForInputFile(
- const std::string &file_name) const {
- // TODO(salcianu): replace our DataBlobView with StringPiece everywhere.
- DataBlobView blob = data_store_.GetData(file_name);
- return StringPiece(reinterpret_cast<const char *>(blob.data()),
- blob.size());
-}
-
} // namespace nlp_core
} // namespace libtextclassifier
diff --git a/common/memory_image/in-memory-model-data.h b/common/memory_image/in-memory-model-data.h
index 5c26388..91e4436 100644
--- a/common/memory_image/in-memory-model-data.h
+++ b/common/memory_image/in-memory-model-data.h
@@ -62,7 +62,9 @@
// file_pattern for a TaskInput from the TaskSpec (see GetTaskSpec()).
// Returns a StringPiece indicating a memory area with the content bytes. On
// error, returns StringPiece(nullptr, 0).
- StringPiece GetBytesForInputFile(const std::string &file_name) const;
+  StringPiece GetBytesForInputFile(const std::string &file_name) const {
+    // DataStore::GetData already hands back an empty StringPiece on its
+    // failure paths (invalid store, unknown key, missing blob index), so the
+    // lookup result is returned unchanged.
+    const StringPiece bytes = data_store_.GetData(file_name);
+    return bytes;
+  }
private:
const memory_image::DataStore data_store_;
diff --git a/common/memory_image/low-level-memory-reader.h b/common/memory_image/low-level-memory-reader.h
index 91953d8..c87c772 100644
--- a/common/memory_image/low-level-memory-reader.h
+++ b/common/memory_image/low-level-memory-reader.h
@@ -21,10 +21,10 @@
#include <string>
-#include "base.h"
-#include "common/memory_image/memory-image-common.h"
+#include "util/base/endian.h"
#include "util/base/integral_types.h"
#include "util/base/logging.h"
+#include "util/strings/stringpiece.h"
namespace libtextclassifier {
namespace nlp_core {
@@ -61,7 +61,7 @@
// On success, sets *view to be a view of the relevant bytes: view.data()
// points to the beginning of the string bytes, and view.size() is the number
// of such bytes.
- bool ReadString(DataBlobView *view) {
+ bool ReadString(StringPiece *view) {
uint32 size;
if (!Read(&size, sizeof(size))) {
TC_LOG(ERROR) << "Unable to read std::string size";
@@ -73,15 +73,15 @@
<< " available < " << size << " required ";
return false;
}
- *view = DataBlobView(current_, size);
+ *view = StringPiece(current_, size);
Advance(size);
return true;
}
- // Like ReadString(DataBlobView *) but reads directly into a C++ string,
- // instead of a DataBlobView (StringPiece-like object).
+  // Like ReadString(StringPiece *) but reads directly into a C++ string
+  // instead of a StringPiece view of the underlying memory.
bool ReadString(std::string *target) {
- DataBlobView view;
+ StringPiece view;
if (!ReadString(&view)) {
return false;
}
diff --git a/common/memory_image/memory-image-common.h b/common/memory_image/memory-image-common.h
index 2e84116..3a46f49 100644
--- a/common/memory_image/memory-image-common.h
+++ b/common/memory_image/memory-image-common.h
@@ -35,38 +35,6 @@
static const int kDefaultAlignment;
};
-// Read-only "view" of a data blob. Does not own the underlying data; instead,
-// just a small object that points to an area of a memory image.
-//
-// TODO(salcianu): replace this class with StringPiece.
-class DataBlobView {
- public:
- DataBlobView() : DataBlobView(nullptr, 0) {}
-
- DataBlobView(const void *start, size_t size)
- : start_(start), size_(size) {}
-
- // Returns start address of a data blob from a memory image.
- const void *data() const { return start_; }
-
- // Returns number of bytes from the data blob starting at start().
- size_t size() const { return size_; }
-
- StringPiece to_stringpiece() const {
- return StringPiece(reinterpret_cast<const char *>(data()),
- size());
- }
-
- // Returns a std::string containing a copy of the data blob bytes.
- std::string ToString() const {
- return to_stringpiece().ToString();
- }
-
- private:
- const void *start_; // Not owned.
- size_t size_;
-};
-
} // namespace nlp_core
} // namespace libtextclassifier
diff --git a/common/memory_image/memory-image-reader.cc b/common/memory_image/memory-image-reader.cc
index c780412..7e717d5 100644
--- a/common/memory_image/memory-image-reader.cc
+++ b/common/memory_image/memory-image-reader.cc
@@ -18,10 +18,10 @@
#include <string>
-#include "base.h"
#include "common/memory_image/low-level-memory-reader.h"
#include "common/memory_image/memory-image-common.h"
#include "common/memory_image/memory-image.pb.h"
+#include "util/base/endian.h"
#include "util/base/logging.h"
namespace libtextclassifier {
diff --git a/common/memory_image/memory-image-reader.h b/common/memory_image/memory-image-reader.h
index 0a20805..c5954fd 100644
--- a/common/memory_image/memory-image-reader.h
+++ b/common/memory_image/memory-image-reader.h
@@ -22,11 +22,11 @@
#include <string>
#include <vector>
-#include "common/memory_image/memory-image-common.h"
#include "common/memory_image/memory-image.pb.h"
#include "util/base/integral_types.h"
#include "util/base/logging.h"
#include "util/base/macros.h"
+#include "util/strings/stringpiece.h"
namespace libtextclassifier {
namespace nlp_core {
@@ -66,11 +66,11 @@
}
// Returns pointer to the beginning of the data blob #i.
- DataBlobView data_blob_view(int i) const {
+ StringPiece data_blob_view(int i) const {
if ((i < 0) || (i >= num_data_blobs())) {
TC_LOG(ERROR) << "Blob index " << i << " outside range [0, "
<< num_data_blobs() << "); will return empty data chunk";
- return DataBlobView();
+ return StringPiece();
}
return data_blob_views_[i];
}
@@ -81,6 +81,12 @@
return trimmed_proto_serialization_.ToString();
}
+  // Non-owning variant of the accessor above: returns the trimmed proto
+  // serialization as a StringPiece that points into the memory image, so no
+  // bytes are copied. The view is only valid while the image bytes are alive.
+  StringPiece trimmed_proto_view() const {
+    const StringPiece view = trimmed_proto_serialization_;
+    return view;
+  }
+
const MemoryImageHeader &header() { return header_; }
protected:
@@ -101,13 +107,13 @@
MemoryImageHeader header_;
// Binary serialization of the trimmed version of the original proto.
- // Represented as a DataBlobView backed up by the underlying memory image
+ // Represented as a StringPiece backed up by the underlying memory image
// bytes.
- DataBlobView trimmed_proto_serialization_;
+ StringPiece trimmed_proto_serialization_;
- // List of DataBlobView objects for all data blobs from the memory image (in
+ // List of StringPiece objects for all data blobs from the memory image (in
// order).
- std::vector<DataBlobView> data_blob_views_;
+ std::vector<StringPiece> data_blob_views_;
// Memory reading success status.
bool success_;
diff --git a/common/mmap.cc b/common/mmap.cc
index d652b3c..6e15a84 100644
--- a/common/mmap.cc
+++ b/common/mmap.cc
@@ -31,13 +31,9 @@
namespace nlp_core {
namespace {
-inline std::string GetLastSystemError() {
- return std::string(strerror(errno));
-}
+// Returns the human-readable description of the current errno value.
+inline std::string GetLastSystemError() { return strerror(errno); }
-inline MmapHandle GetErrorMmapHandle() {
- return MmapHandle(nullptr, 0);
-}
+// Handle returned by the MmapFile overloads on failure: null start, no bytes.
+inline MmapHandle GetErrorMmapHandle() {
+  return {/*start=*/nullptr, /*num_bytes=*/0};
+}
class FileCloser {
public:
@@ -49,11 +45,13 @@
TC_LOG(ERROR) << "Error closing file descriptor: " << last_error;
}
}
+
private:
const int fd_;
TC_DISALLOW_COPY_AND_ASSIGN(FileCloser);
};
+
} // namespace
MmapHandle MmapFile(const std::string &filename) {
@@ -81,7 +79,15 @@
TC_LOG(ERROR) << "Unable to stat fd: " << last_error;
return GetErrorMmapHandle();
}
- size_t file_size_in_bytes = static_cast<size_t>(sb.st_size);
+
+ return MmapFile(fd, /*segment_offset=*/0, /*segment_size=*/sb.st_size);
+}
+
+MmapHandle MmapFile(int fd, int64 segment_offset, int64 segment_size) {
+ static const int64 kPageSize = sysconf(_SC_PAGE_SIZE);
+ const int64 aligned_offset = (segment_offset / kPageSize) * kPageSize;
+ const int64 alignment_shift = segment_offset - aligned_offset;
+ const int64 aligned_length = segment_size + alignment_shift;
// Perform actual mmap.
void *mmap_addr = mmap(
@@ -89,8 +95,7 @@
// Let system pick address for mmapp-ed data.
nullptr,
- // Mmap all bytes from the file.
- file_size_in_bytes,
+ aligned_length,
// One can read / write the mapped data (but see MAP_PRIVATE below).
// Normally, we expect only to read it, but in the future, we may want to
@@ -104,16 +109,15 @@
// Descriptor of file to mmap.
fd,
- // Map bytes right from the beginning of the file. This, and
- // file_size_in_bytes (2nd argument) means we map all bytes from the file.
- 0);
+ aligned_offset);
if (mmap_addr == MAP_FAILED) {
const std::string last_error = GetLastSystemError();
- TC_LOG(ERROR) << "Error while mmaping: " << last_error;
+ TC_LOG(ERROR) << "Error while mmapping: " << last_error;
return GetErrorMmapHandle();
}
- return MmapHandle(mmap_addr, file_size_in_bytes);
+ return MmapHandle(static_cast<char *>(mmap_addr) + alignment_shift,
+ segment_size, /*unmap_addr=*/mmap_addr);
}
bool Unmap(MmapHandle mmap_handle) {
@@ -121,7 +125,7 @@
// Unmapping something that hasn't been mapped is trivially successful.
return true;
}
- if (munmap(mmap_handle.start(), mmap_handle.num_bytes()) != 0) {
+ if (munmap(mmap_handle.unmap_addr(), mmap_handle.num_bytes()) != 0) {
const std::string last_error = GetLastSystemError();
TC_LOG(ERROR) << "Error during Unmap / munmap: " << last_error;
return false;
diff --git a/common/mmap.h b/common/mmap.h
index ec79c9c..69f7b4c 100644
--- a/common/mmap.h
+++ b/common/mmap.h
@@ -21,6 +21,7 @@
#include <string>
+#include "util/base/integral_types.h"
#include "util/strings/stringpiece.h"
namespace libtextclassifier {
@@ -39,12 +40,22 @@
// are ok keeping that file in memory the whole time).
class MmapHandle {
public:
- MmapHandle(void *start, size_t num_bytes)
- : start_(start), num_bytes_(num_bytes) {}
+ MmapHandle(void *start, size_t num_bytes, void *unmap_addr = nullptr)
+ : start_(start), num_bytes_(num_bytes), unmap_addr_(unmap_addr) {}
// Returns start address for the memory area where a file has been mmapped.
void *start() const { return start_; }
+  // Returns the address that must be passed to munmap(). When the handle was
+  // built without an explicit unmap_addr (i.e. it is nullptr), this is start().
+  //
+  // NOTE(review): for segment mappings created by MmapFile(fd, offset, size),
+  // unmap_addr() precedes start() by the page-alignment shift (less than one
+  // page), so the mapped length is num_bytes() + (start() - unmap_addr()).
+  // Unmap() currently passes only num_bytes() to munmap, which can leave a
+  // trailing page mapped; confirm and fix there.
+  void *unmap_addr() const {
+    return unmap_addr_ != nullptr ? unmap_addr_ : start_;
+  }
+
// Returns number of bytes of the memory area from start().
size_t num_bytes() const { return num_bytes_; }
@@ -63,6 +74,9 @@
// See doc for num_bytes().
const size_t num_bytes_;
+
+ // Address to use for unmapping.
+ void *const unmap_addr_;
};
// Maps the full content of a file in memory (using mmap).
@@ -88,6 +102,13 @@
// Like MmapFile(const std::string &filename), but uses a file descriptor.
MmapHandle MmapFile(int fd);
+// Maps a segment of a file to memory. The file is given by a file descriptor;
+// offset (relative to the beginning of the file) and size specify the segment
+// to be mapped. NOTE: Internally, the offset passed to the mmap system call is
+// rounded down to a multiple of the page size, so segment_offset does NOT have
+// to be a multiple of the page size.
+MmapHandle MmapFile(int fd, int64 segment_offset, int64 segment_size);
+
// Unmaps a file mapped using MmapFile. Returns true on success, false
// otherwise.
bool Unmap(MmapHandle mmap_handle);
@@ -99,8 +120,10 @@
explicit ScopedMmap(const std::string &filename)
: handle_(MmapFile(filename)) {}
- explicit ScopedMmap(int fd)
- : handle_(MmapFile(fd)) {}
+ explicit ScopedMmap(int fd) : handle_(MmapFile(fd)) {}
+
+  // Maps the segment [segment_offset, segment_offset + segment_size) of the
+  // file behind `fd`. The offset and size are 64-bit, matching
+  // MmapFile(int, int64, int64), so values beyond INT_MAX are not truncated
+  // on the way through; existing callers passing int still work unchanged.
+  ScopedMmap(int fd, int64 segment_offset, int64 segment_size)
+      : handle_(MmapFile(fd, segment_offset, segment_size)) {}
~ScopedMmap() {
if (handle_.ok()) {
diff --git a/tests/functions.cc b/common/mock_functions.cc
similarity index 95%
rename from tests/functions.cc
rename to common/mock_functions.cc
index 8ea5a8d..c661b70 100644
--- a/tests/functions.cc
+++ b/common/mock_functions.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "tests/functions.h"
+#include "common/mock_functions.h"
#include "common/registry.h"
diff --git a/tests/functions.h b/common/mock_functions.h
similarity index 91%
rename from tests/functions.h
rename to common/mock_functions.h
index b96fe2d..b5bcb07 100644
--- a/tests/functions.h
+++ b/common/mock_functions.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_TESTS_FUNCTIONS_H_
-#define LIBTEXTCLASSIFIER_TESTS_FUNCTIONS_H_
+#ifndef LIBTEXTCLASSIFIER_COMMON_MOCK_FUNCTIONS_H_
+#define LIBTEXTCLASSIFIER_COMMON_MOCK_FUNCTIONS_H_
#include <math.h>
@@ -24,6 +24,7 @@
namespace libtextclassifier {
namespace nlp_core {
namespace functions {
+
// Abstract double -> double function.
class Function : public RegisterableClass<Function> {
public:
@@ -61,6 +62,7 @@
int Evaluate(int k) override { return k + 1; }
TC_DEFINE_REGISTRATION_METHOD("dec", Dec);
};
+
} // namespace functions
// Should be inside namespace libtextclassifier::nlp_core.
@@ -70,4 +72,4 @@
} // namespace nlp_core
} // namespace libtextclassifier
-#endif // LIBTEXTCLASSIFIER_TESTS_FUNCTIONS_H_
+#endif // LIBTEXTCLASSIFIER_COMMON_MOCK_FUNCTIONS_H_
diff --git a/tests/registry_test.cc b/common/registry_test.cc
similarity index 98%
rename from tests/registry_test.cc
rename to common/registry_test.cc
index 7de4163..d5d7006 100644
--- a/tests/registry_test.cc
+++ b/common/registry_test.cc
@@ -16,7 +16,7 @@
#include <memory>
-#include "tests/functions.h"
+#include "common/mock_functions.h"
#include "gtest/gtest.h"
namespace libtextclassifier {
diff --git a/tests/lang-id_test.cc b/lang_id/lang-id_test.cc
similarity index 99%
rename from tests/lang-id_test.cc
rename to lang_id/lang-id_test.cc
index 3a735c2..2f8aedd 100644
--- a/tests/lang-id_test.cc
+++ b/lang_id/lang-id_test.cc
@@ -21,7 +21,6 @@
#include <utility>
#include <vector>
-#include "base.h"
#include "util/base/logging.h"
#include "gtest/gtest.h"
diff --git a/models/textclassifier.smartselection.en.model b/models/textclassifier.smartselection.en.model
index 315e2b4..7af0897 100644
--- a/models/textclassifier.smartselection.en.model
+++ b/models/textclassifier.smartselection.en.model
Binary files differ
diff --git a/smartselect/cached-features.h b/smartselect/cached-features.h
index 6490748..990233c 100644
--- a/smartselect/cached-features.h
+++ b/smartselect/cached-features.h
@@ -20,7 +20,6 @@
#include <memory>
#include <vector>
-#include "base.h"
#include "common/vector-span.h"
#include "smartselect/types.h"
diff --git a/tests/cached-features_test.cc b/smartselect/cached-features_test.cc
similarity index 100%
rename from tests/cached-features_test.cc
rename to smartselect/cached-features_test.cc
diff --git a/smartselect/feature-processor.cc b/smartselect/feature-processor.cc
index 1b15982..c1db95a 100644
--- a/smartselect/feature-processor.cc
+++ b/smartselect/feature-processor.cc
@@ -24,9 +24,11 @@
#include "util/base/logging.h"
#include "util/strings/utf8.h"
#include "util/utf8/unicodetext.h"
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
#include "unicode/brkiter.h"
#include "unicode/errorcode.h"
#include "unicode/uchar.h"
+#endif
namespace libtextclassifier {
@@ -51,6 +53,10 @@
extractor_options.remap_digits = options.remap_digits();
extractor_options.lowercase_tokens = options.lowercase_tokens();
+ for (const auto& chargram : options.allowed_chargrams()) {
+ extractor_options.allowed_chargrams.insert(chargram);
+ }
+
return extractor_options;
}
@@ -113,34 +119,14 @@
}
}
-void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
- std::vector<UnicodeTextRange>* ranges) {
- UnicodeText::const_iterator start = t.begin();
- UnicodeText::const_iterator curr = start;
- UnicodeText::const_iterator end = t.end();
- for (; curr != end; ++curr) {
- if (codepoints.find(*curr) != codepoints.end()) {
- if (start != curr) {
- ranges->push_back(std::make_pair(start, curr));
- }
- start = curr;
- ++start;
- }
- }
- if (start != end) {
- ranges->push_back(std::make_pair(start, end));
- }
-}
+} // namespace internal
-void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
- std::vector<Token>* tokens) {
+void FeatureProcessor::StripTokensFromOtherLines(
+ const std::string& context, CodepointSpan span,
+ std::vector<Token>* tokens) const {
const UnicodeText context_unicode = UTF8ToUnicodeText(context,
/*do_copy=*/false);
- std::vector<UnicodeTextRange> lines;
- std::set<char32> codepoints;
- codepoints.insert('\n');
- codepoints.insert('|');
- internal::FindSubstrings(context_unicode, codepoints, &lines);
+ std::vector<UnicodeTextRange> lines = SplitContext(context_unicode);
auto span_start = context_unicode.begin();
if (span.first > 0) {
@@ -170,11 +156,11 @@
}
}
-} // namespace internal
-
std::string FeatureProcessor::GetDefaultCollection() const {
- if (options_.default_collection() >= options_.collections_size()) {
- TC_LOG(ERROR) << "No collections specified. Returning empty string.";
+ if (options_.default_collection() < 0 ||
+ options_.default_collection() >= options_.collections_size()) {
+ TC_LOG(ERROR)
+ << "Invalid or missing default collection. Returning empty string.";
return "";
}
return options_.collections(options_.default_collection());
@@ -217,18 +203,38 @@
return false;
}
- const int result_begin_token = token_span.first;
- const int result_begin_codepoint =
- tokens[options_.context_size() - result_begin_token].start;
- const int result_end_token = token_span.second;
- const int result_end_codepoint =
- tokens[options_.context_size() + result_end_token].end;
+ const int result_begin_token_index = token_span.first;
+ const Token& result_begin_token =
+ tokens[options_.context_size() - result_begin_token_index];
+ const int result_begin_codepoint = result_begin_token.start;
+ const int result_end_token_index = token_span.second;
+ const Token& result_end_token =
+ tokens[options_.context_size() + result_end_token_index];
+ const int result_end_codepoint = result_end_token.end;
if (result_begin_codepoint == kInvalidIndex ||
result_end_codepoint == kInvalidIndex) {
*span = CodepointSpan({kInvalidIndex, kInvalidIndex});
} else {
- *span = CodepointSpan({result_begin_codepoint, result_end_codepoint});
+ const UnicodeText token_begin_unicode =
+ UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false);
+ UnicodeText::const_iterator token_begin = token_begin_unicode.begin();
+ const UnicodeText token_end_unicode =
+ UTF8ToUnicodeText(result_end_token.value, /*do_copy=*/false);
+ UnicodeText::const_iterator token_end = token_end_unicode.end();
+
+ const int begin_ignored = CountIgnoredSpanBoundaryCodepoints(
+ token_begin, token_begin_unicode.end(), /*count_from_beginning=*/true);
+ const int end_ignored = CountIgnoredSpanBoundaryCodepoints(
+ token_end_unicode.begin(), token_end, /*count_from_beginning=*/false);
+ // In case everything would be stripped, set the span to the original
+ // beginning and zero length.
+ if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) {
+ *span = {result_begin_codepoint, result_begin_codepoint};
+ } else {
+ *span = CodepointSpan({result_begin_codepoint + begin_ignored,
+ result_end_codepoint - end_ignored});
+ }
}
return true;
}
@@ -274,14 +280,28 @@
// Check that the spanned tokens cover the whole span.
bool tokens_match_span;
+ const CodepointIndex tokens_start = tokens[click_position - span_left].start;
+ const CodepointIndex tokens_end = tokens[click_position + span_right].end;
if (options_.snap_label_span_boundaries_to_containing_tokens()) {
- tokens_match_span =
- tokens[click_position - span_left].start <= span.first &&
- tokens[click_position + span_right].end >= span.second;
+ tokens_match_span = tokens_start <= span.first && tokens_end >= span.second;
} else {
- tokens_match_span =
- tokens[click_position - span_left].start == span.first &&
- tokens[click_position + span_right].end == span.second;
+ const UnicodeText token_left_unicode = UTF8ToUnicodeText(
+ tokens[click_position - span_left].value, /*do_copy=*/false);
+ const UnicodeText token_right_unicode = UTF8ToUnicodeText(
+ tokens[click_position + span_right].value, /*do_copy=*/false);
+
+ UnicodeText::const_iterator span_begin = token_left_unicode.begin();
+ UnicodeText::const_iterator span_end = token_right_unicode.end();
+
+ const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints(
+ span_begin, token_left_unicode.end(), /*count_from_beginning=*/true);
+ const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints(
+ token_right_unicode.begin(), span_end, /*count_from_beginning=*/false);
+
+ tokens_match_span = tokens_start <= span.first &&
+ tokens_start + num_punctuation_start >= span.first &&
+ tokens_end >= span.second &&
+ tokens_end - num_punctuation_end <= span.second;
}
if (tokens_match_span) {
@@ -303,16 +323,23 @@
}
TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
- CodepointSpan codepoint_span) {
+ CodepointSpan codepoint_span,
+ bool snap_boundaries_to_containing_tokens) {
const int codepoint_start = std::get<0>(codepoint_span);
const int codepoint_end = std::get<1>(codepoint_span);
TokenIndex start_token = kInvalidIndex;
TokenIndex end_token = kInvalidIndex;
for (int i = 0; i < selectable_tokens.size(); ++i) {
- if (codepoint_start <= selectable_tokens[i].start &&
- codepoint_end >= selectable_tokens[i].end &&
- !selectable_tokens[i].is_padding) {
+ bool is_token_in_span;
+ if (snap_boundaries_to_containing_tokens) {
+ is_token_in_span = codepoint_start < selectable_tokens[i].end &&
+ codepoint_end > selectable_tokens[i].start;
+ } else {
+ is_token_in_span = codepoint_start <= selectable_tokens[i].start &&
+ codepoint_end >= selectable_tokens[i].end;
+ }
+ if (is_token_in_span && !selectable_tokens[i].is_padding) {
if (start_token == kInvalidIndex) {
start_token = i;
}
@@ -453,6 +480,114 @@
});
}
+// Copies the ignored span-boundary codepoints configured in options_ into the
+// ignored_span_boundary_codepoints_ lookup set.
+void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
+  const auto& configured = options_.ignored_span_boundary_codepoints();
+  ignored_span_boundary_codepoints_.insert(configured.begin(),
+                                           configured.end());
+}
+
+// Returns how many consecutive codepoints at one edge of the range
+// [span_start, span_end) belong to ignored_span_boundary_codepoints_: the
+// leading run when count_from_beginning is true, otherwise the trailing run.
+// An empty range yields 0; a range made only of ignored codepoints yields its
+// full length.
+int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
+    const UnicodeText::const_iterator& span_start,
+    const UnicodeText::const_iterator& span_end,
+    bool count_from_beginning) const {
+  int run_length = 0;
+  if (count_from_beginning) {
+    // Walk forward from the first codepoint until a non-ignored one appears.
+    for (UnicodeText::const_iterator cursor = span_start; cursor != span_end;
+         ++cursor) {
+      if (ignored_span_boundary_codepoints_.find(*cursor) ==
+          ignored_span_boundary_codepoints_.end()) {
+        break;
+      }
+      ++run_length;
+    }
+  } else {
+    // Walk backward from the last codepoint; the decrement happens before the
+    // check, so span_start itself is still examined.
+    UnicodeText::const_iterator cursor = span_end;
+    while (cursor != span_start) {
+      --cursor;
+      if (ignored_span_boundary_codepoints_.find(*cursor) ==
+          ignored_span_boundary_codepoints_.end()) {
+        break;
+      }
+      ++run_length;
+    }
+  }
+  return run_length;
+}
+
+namespace {
+
+// Splits `t` at every codepoint contained in `codepoints` and appends each
+// resulting non-empty piece, as an iterator range into `t`, to `ranges`. The
+// separator codepoints themselves are not part of any piece.
+void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
+                    std::vector<UnicodeTextRange>* ranges) {
+  UnicodeText::const_iterator piece_begin = t.begin();
+  const UnicodeText::const_iterator text_end = t.end();
+  for (UnicodeText::const_iterator it = t.begin(); it != text_end; ++it) {
+    const bool is_separator = codepoints.count(*it) > 0;
+    if (!is_separator) {
+      continue;
+    }
+    if (piece_begin != it) {
+      ranges->emplace_back(piece_begin, it);
+    }
+    // The next piece starts right after the separator.
+    piece_begin = it;
+    ++piece_begin;
+  }
+  if (piece_begin != text_end) {
+    ranges->emplace_back(piece_begin, text_end);
+  }
+}
+
+}  // namespace
+
+// Splits the context into the ranges features may be drawn from. When
+// only_use_line_with_click is set, the context is cut at '\n' and '|'
+// (empty pieces are dropped); otherwise the whole context is a single range.
+std::vector<UnicodeTextRange> FeatureProcessor::SplitContext(
+    const UnicodeText& context_unicode) const {
+  if (!options_.only_use_line_with_click()) {
+    return {{context_unicode.begin(), context_unicode.end()}};
+  }
+
+  const std::set<char32> line_separators = {'\n', '|'};
+  std::vector<UnicodeTextRange> lines;
+  FindSubstrings(context_unicode, line_separators, &lines);
+  return lines;
+}
+
+// Shrinks `span` (codepoint indices into `context`) by dropping ignored
+// boundary codepoints (ignored_span_boundary_codepoints_) from both ends.
+// If nothing would remain, returns the zero-length span
+// {span.first, span.first}.
+// Spans with a negative start or with end < start are returned unchanged:
+// such spans cannot be located in the text, and advancing iterators by them
+// would walk outside the UnicodeText.
+CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
+    const std::string& context, CodepointSpan span) const {
+  if (span.first < 0 || span.second < span.first) {
+    return span;
+  }
+
+  const UnicodeText context_unicode =
+      UTF8ToUnicodeText(context, /*do_copy=*/false);
+  UnicodeText::const_iterator span_begin = context_unicode.begin();
+  std::advance(span_begin, span.first);
+  // Continue from span_begin instead of re-walking the text from its start;
+  // UTF-8 iteration is linear in the number of codepoints.
+  UnicodeText::const_iterator span_end = span_begin;
+  std::advance(span_end, span.second - span.first);
+
+  const int start_offset = CountIgnoredSpanBoundaryCodepoints(
+      span_begin, span_end, /*count_from_beginning=*/true);
+  const int end_offset = CountIgnoredSpanBoundaryCodepoints(
+      span_begin, span_end, /*count_from_beginning=*/false);
+
+  if (span.first + start_offset < span.second - end_offset) {
+    return {span.first + start_offset, span.second - end_offset};
+  }
+  // Everything in the span was ignorable: collapse to its original start.
+  return {span.first, span.first};
+}
+
float FeatureProcessor::SupportedCodepointsRatio(
int click_pos, const std::vector<Token>& tokens) const {
int num_supported = 0;
@@ -550,7 +685,7 @@
}
if (options_.only_use_line_with_click()) {
- internal::StripTokensFromOtherLines(context, input_span, tokens);
+ StripTokensFromOtherLines(context, input_span, tokens);
}
int local_click_pos;
@@ -605,13 +740,20 @@
std::unique_ptr<CachedFeatures>* cached_features) const {
TokenizeAndFindClick(context, input_span, tokens, click_pos);
- // If the default click method failed, let's try to do sub-token matching
- // before we fail.
- if (*click_pos == kInvalidIndex) {
- *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
+ if (input_span.first != kInvalidIndex && input_span.second != kInvalidIndex) {
+ // If the default click method failed, let's try to do sub-token matching
+ // before we fail.
if (*click_pos == kInvalidIndex) {
- return false;
+ *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
+ if (*click_pos == kInvalidIndex) {
+ return false;
+ }
}
+ } else {
+ // If input_span is unspecified, click the first token and extract features
+ // from all tokens.
+ *click_pos = 0;
+ relative_click_span = {0, tokens->size()};
}
internal::StripOrPadTokens(relative_click_span, options_.context_size(),
@@ -621,8 +763,8 @@
const float supported_codepoint_ratio =
SupportedCodepointsRatio(*click_pos, *tokens);
if (supported_codepoint_ratio < options_.min_supported_codepoint_ratio()) {
- TC_LOG(INFO) << "Not enough supported codepoints in the context: "
- << supported_codepoint_ratio;
+ TC_VLOG(1) << "Not enough supported codepoints in the context: "
+ << supported_codepoint_ratio;
return false;
}
}
@@ -658,6 +800,7 @@
bool FeatureProcessor::ICUTokenize(const std::string& context,
std::vector<Token>* result) const {
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
icu::ErrorCode status;
icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(context);
std::unique_ptr<icu::BreakIterator> break_iterator(
@@ -699,6 +842,10 @@
}
return true;
+#else
+ TC_LOG(WARNING) << "Can't tokenize, ICU not supported";
+ return false;
+#endif
}
void FeatureProcessor::InternalRetokenize(const std::string& context,
@@ -758,6 +905,8 @@
// Run the tokenizer and update the token bounds to reflect the offset of the
// substring.
std::vector<Token> tokens = tokenizer_.Tokenize(text);
+ // Avoids progressive capacity increases in the for loop.
+ result->reserve(result->size() + tokens.size());
for (Token& token : tokens) {
token.start += span.first;
token.end += span.first;
diff --git a/smartselect/feature-processor.h b/smartselect/feature-processor.h
index 2c64b67..ef9a3df 100644
--- a/smartselect/feature-processor.h
+++ b/smartselect/feature-processor.h
@@ -20,6 +20,7 @@
#define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
#include <memory>
+#include <set>
#include <string>
#include <vector>
@@ -52,11 +53,6 @@
TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
const FeatureProcessorOptions& options);
-// Removes tokens that are not part of a line of the context which contains
-// given span.
-void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
- std::vector<Token>* tokens);
-
// Splits tokens that contain the selection boundary inside them.
// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
@@ -80,8 +76,12 @@
} // namespace internal
// Converts a codepoint span to a token span in the given list of tokens.
-TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
- CodepointSpan codepoint_span);
+// If snap_boundaries_to_containing_tokens is set to true, it is enough for a
+// token to overlap with the codepoint range to be considered part of it.
+// Otherwise it must be fully included in the range.
+TokenSpan CodepointSpanToTokenSpan(
+ const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span,
+ bool snap_boundaries_to_containing_tokens = false);
// Converts a token span to a codepoint span in the given list of tokens.
CodepointSpan TokenSpanToCodepointSpan(
@@ -104,6 +104,7 @@
{options.internal_tokenizer_codepoint_ranges().begin(),
options.internal_tokenizer_codepoint_ranges().end()},
&internal_tokenizer_codepoint_ranges_);
+ PrepareIgnoredSpanBoundaryCodepoints();
}
explicit FeatureProcessor(const std::string& serialized_options)
@@ -137,6 +138,8 @@
// Extracts features as a CachedFeatures object that can be used for repeated
// inference over token spans in the given context.
+ // When input_span == {kInvalidIndex, kInvalidIndex} then, relative_click_span
+ // is ignored, and all tokens extracted from context will be considered.
bool ExtractFeatures(const std::string& context, CodepointSpan input_span,
TokenSpan relative_click_span,
const FeatureVectorFn& feature_vector_fn,
@@ -155,6 +158,16 @@
return feature_extractor_.DenseFeaturesCount();
}
+ // Splits context to several segments according to configuration.
+ std::vector<UnicodeTextRange> SplitContext(
+ const UnicodeText& context_unicode) const;
+
+ // Strips boundary codepoints from the span in context and returns the new
+ // start and end indices. If the span comprises entirely of boundary
+ // codepoints, the first index of span is returned for both indices.
+ CodepointSpan StripBoundaryCodepoints(const std::string& context,
+ CodepointSpan span) const;
+
protected:
// Represents a codepoint range [start, end).
struct CodepointRange {
@@ -207,6 +220,18 @@
bool IsCodepointInRanges(
int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;
+ void PrepareIgnoredSpanBoundaryCodepoints();
+
+ // Counts the number of span boundary codepoints. If count_from_beginning is
+ // True, the counting will start at the span_start iterator (inclusive) and at
+ // maximum end at span_end (exclusive). If count_from_beginning is True, the
+ // counting will start from span_end (exclusive) and end at span_start
+ // (inclusive).
+ int CountIgnoredSpanBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_start,
+ const UnicodeText::const_iterator& span_end,
+ bool count_from_beginning) const;
+
// Finds the center token index in tokens vector, using the method defined
// in options_.
int FindCenterToken(CodepointSpan span,
@@ -227,6 +252,11 @@
void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
std::vector<Token>* result) const;
+ // Removes all tokens from tokens that are not on a line (defined by calling
+ // SplitContext on the context) to which span points.
+ void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
+ std::vector<Token>* tokens) const;
+
const TokenFeatureExtractor feature_extractor_;
// Codepoint ranges that define what codepoints are supported by the model.
@@ -240,6 +270,10 @@
std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;
private:
+ // Set of codepoints that will be stripped from beginning and end of
+ // predicted spans.
+ std::set<int32> ignored_span_boundary_codepoints_;
+
const FeatureProcessorOptions options_;
// Mapping between token selection spans and labels ids.
diff --git a/tests/feature-processor_test.cc b/smartselect/feature-processor_test.cc
similarity index 70%
rename from tests/feature-processor_test.cc
rename to smartselect/feature-processor_test.cc
index 4e27afc..9bee67a 100644
--- a/tests/feature-processor_test.cc
+++ b/smartselect/feature-processor_test.cc
@@ -25,6 +25,18 @@
using testing::ElementsAreArray;
using testing::FloatEq;
+class TestingFeatureProcessor : public FeatureProcessor {
+ public:
+ using FeatureProcessor::CountIgnoredSpanBoundaryCodepoints;
+ using FeatureProcessor::FeatureProcessor;
+ using FeatureProcessor::ICUTokenize;
+ using FeatureProcessor::IsCodepointInRanges;
+ using FeatureProcessor::SpanToLabel;
+ using FeatureProcessor::StripTokensFromOtherLines;
+ using FeatureProcessor::supported_codepoint_ranges_;
+ using FeatureProcessor::SupportedCodepointsRatio;
+};
+
TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) {
std::vector<Token> tokens{Token("Hělló", 0, 5),
Token("fěěbař@google.com", 6, 23),
@@ -107,6 +119,10 @@
}
TEST(FeatureProcessorTest, KeepLineWithClickFirst) {
+ FeatureProcessorOptions options;
+ options.set_only_use_line_with_click(true);
+ TestingFeatureProcessor feature_processor(options);
+
const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
const CodepointSpan span = {0, 5};
// clang-format off
@@ -119,12 +135,16 @@
// clang-format on
// Keeps the first line.
- internal::StripTokensFromOtherLines(context, span, &tokens);
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
EXPECT_THAT(tokens,
ElementsAreArray({Token("Fiřst", 0, 5), Token("Lině", 6, 10)}));
}
TEST(FeatureProcessorTest, KeepLineWithClickSecond) {
+ FeatureProcessorOptions options;
+ options.set_only_use_line_with_click(true);
+ TestingFeatureProcessor feature_processor(options);
+
const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
const CodepointSpan span = {18, 22};
// clang-format off
@@ -137,12 +157,16 @@
// clang-format on
// Keeps the first line.
- internal::StripTokensFromOtherLines(context, span, &tokens);
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
EXPECT_THAT(tokens, ElementsAreArray(
{Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
}
TEST(FeatureProcessorTest, KeepLineWithClickThird) {
+ FeatureProcessorOptions options;
+ options.set_only_use_line_with_click(true);
+ TestingFeatureProcessor feature_processor(options);
+
const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
const CodepointSpan span = {24, 33};
// clang-format off
@@ -155,12 +179,16 @@
// clang-format on
// Keeps the first line.
- internal::StripTokensFromOtherLines(context, span, &tokens);
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
EXPECT_THAT(tokens, ElementsAreArray(
{Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
}
TEST(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) {
+ FeatureProcessorOptions options;
+ options.set_only_use_line_with_click(true);
+ TestingFeatureProcessor feature_processor(options);
+
const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
const CodepointSpan span = {18, 22};
// clang-format off
@@ -173,12 +201,16 @@
// clang-format on
// Keeps the first line.
- internal::StripTokensFromOtherLines(context, span, &tokens);
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
EXPECT_THAT(tokens, ElementsAreArray(
{Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
}
TEST(FeatureProcessorTest, KeepLineWithCrosslineClick) {
+ FeatureProcessorOptions options;
+ options.set_only_use_line_with_click(true);
+ TestingFeatureProcessor feature_processor(options);
+
const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
const CodepointSpan span = {5, 23};
// clang-format off
@@ -191,23 +223,13 @@
// clang-format on
// Keeps the first line.
- internal::StripTokensFromOtherLines(context, span, &tokens);
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
EXPECT_THAT(tokens, ElementsAreArray(
{Token("Fiřst", 0, 5), Token("Lině", 6, 10),
Token("Sěcond", 18, 23), Token("Lině", 19, 23),
Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
}
-class TestingFeatureProcessor : public FeatureProcessor {
- public:
- using FeatureProcessor::FeatureProcessor;
- using FeatureProcessor::SpanToLabel;
- using FeatureProcessor::SupportedCodepointsRatio;
- using FeatureProcessor::IsCodepointInRanges;
- using FeatureProcessor::ICUTokenize;
- using FeatureProcessor::supported_codepoint_ranges_;
-};
-
TEST(FeatureProcessorTest, SpanToLabel) {
FeatureProcessorOptions options;
options.set_context_size(1);
@@ -270,6 +292,68 @@
EXPECT_EQ(label2, label3);
}
+TEST(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) {
+ FeatureProcessorOptions options;
+ options.set_context_size(1);
+ options.set_max_selection_span(1);
+ options.set_snap_label_span_boundaries_to_containing_tokens(false);
+
+ TokenizationCodepointRange* config =
+ options.add_tokenization_codepoint_config();
+ config->set_start(32);
+ config->set_end(33);
+ config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+
+ TestingFeatureProcessor feature_processor(options);
+ std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
+ ASSERT_EQ(3, tokens.size());
+ int label;
+ ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label));
+ EXPECT_EQ(kInvalidLabel, label);
+ ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label));
+ EXPECT_NE(kInvalidLabel, label);
+ TokenSpan token_span;
+ feature_processor.LabelToTokenSpan(label, &token_span);
+ EXPECT_EQ(0, token_span.first);
+ EXPECT_EQ(0, token_span.second);
+
+ // Reconfigure with snapping enabled.
+ options.set_snap_label_span_boundaries_to_containing_tokens(true);
+ TestingFeatureProcessor feature_processor2(options);
+ int label2;
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+
+ // Cross a token boundary.
+ ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2));
+ EXPECT_EQ(kInvalidLabel, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2));
+ EXPECT_EQ(kInvalidLabel, label2);
+
+ // Multiple tokens.
+ options.set_context_size(2);
+ options.set_max_selection_span(2);
+ TestingFeatureProcessor feature_processor3(options);
+ tokens = feature_processor3.Tokenize("zero, one, two, three, four");
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
+ EXPECT_NE(kInvalidLabel, label2);
+ feature_processor3.LabelToTokenSpan(label2, &token_span);
+ EXPECT_EQ(1, token_span.first);
+ EXPECT_EQ(0, token_span.second);
+
+ int label3;
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+ ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+}
+
TEST(FeatureProcessorTest, CenterTokenFromClick) {
int token_index;
@@ -610,5 +694,144 @@
// clang-format on
}
+TEST(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) {
+ FeatureProcessorOptions options;
+ options.add_ignored_span_boundary_codepoints('.');
+ options.add_ignored_span_boundary_codepoints(',');
+ options.add_ignored_span_boundary_codepoints('[');
+ options.add_ignored_span_boundary_codepoints(']');
+
+ TestingFeatureProcessor feature_processor(options);
+
+ const std::string text1_utf8 = "ěščř";
+ const UnicodeText text1 = UTF8ToUnicodeText(text1_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text1.begin(), text1.end(),
+ /*count_from_beginning=*/true),
+ 0);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text1.begin(), text1.end(),
+ /*count_from_beginning=*/false),
+ 0);
+
+ const std::string text2_utf8 = ".,abčd";
+ const UnicodeText text2 = UTF8ToUnicodeText(text2_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text2.begin(), text2.end(),
+ /*count_from_beginning=*/true),
+ 2);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text2.begin(), text2.end(),
+ /*count_from_beginning=*/false),
+ 0);
+
+ const std::string text3_utf8 = ".,abčd[]";
+ const UnicodeText text3 = UTF8ToUnicodeText(text3_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text3.begin(), text3.end(),
+ /*count_from_beginning=*/true),
+ 2);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text3.begin(), text3.end(),
+ /*count_from_beginning=*/false),
+ 2);
+
+ const std::string text4_utf8 = "[abčd]";
+ const UnicodeText text4 = UTF8ToUnicodeText(text4_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text4.begin(), text4.end(),
+ /*count_from_beginning=*/true),
+ 1);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text4.begin(), text4.end(),
+ /*count_from_beginning=*/false),
+ 1);
+
+ const std::string text5_utf8 = "";
+ const UnicodeText text5 = UTF8ToUnicodeText(text5_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text5.begin(), text5.end(),
+ /*count_from_beginning=*/true),
+ 0);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text5.begin(), text5.end(),
+ /*count_from_beginning=*/false),
+ 0);
+
+ const std::string text6_utf8 = "012345ěščř";
+ const UnicodeText text6 = UTF8ToUnicodeText(text6_utf8, /*do_copy=*/false);
+ UnicodeText::const_iterator text6_begin = text6.begin();
+ std::advance(text6_begin, 6);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text6_begin, text6.end(),
+ /*count_from_beginning=*/true),
+ 0);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text6_begin, text6.end(),
+ /*count_from_beginning=*/false),
+ 0);
+
+ const std::string text7_utf8 = "012345.,ěščř";
+ const UnicodeText text7 = UTF8ToUnicodeText(text7_utf8, /*do_copy=*/false);
+ UnicodeText::const_iterator text7_begin = text7.begin();
+ std::advance(text7_begin, 6);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text7_begin, text7.end(),
+ /*count_from_beginning=*/true),
+ 2);
+ UnicodeText::const_iterator text7_end = text7.begin();
+ std::advance(text7_end, 8);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text7.begin(), text7_end,
+ /*count_from_beginning=*/false),
+ 2);
+
+ // Test not stripping.
+ EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
+ "Hello [[[Wořld]] or not?", {0, 24}),
+ std::make_pair(0, 24));
+ // Test basic stripping.
+ EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
+ "Hello [[[Wořld]] or not?", {6, 16}),
+ std::make_pair(9, 14));
+ // Test stripping when everything is stripped.
+ EXPECT_EQ(
+ feature_processor.StripBoundaryCodepoints("Hello [[[]] or not?", {6, 11}),
+ std::make_pair(6, 6));
+ // Test stripping empty string.
+ EXPECT_EQ(feature_processor.StripBoundaryCodepoints("", {0, 0}),
+ std::make_pair(0, 0));
+}
+
+TEST(FeatureProcessorTest, CodepointSpanToTokenSpan) {
+ const std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ // Spans matching the tokens exactly.
+ EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}));
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}));
+ EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}));
+ EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}));
+ EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}));
+ EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}));
+
+ // Snapping to containing tokens has no effect.
+ EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}, true));
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}, true));
+ EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}, true));
+ EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}, true));
+ EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}, true));
+ EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}, true));
+
+ // Span boundaries inside tokens.
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {1, 28}));
+ EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {1, 28}, true));
+
+ // Tokens adjacent to the span, but not overlapping.
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}));
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}, true));
+}
+
} // namespace
} // namespace libtextclassifier
diff --git a/smartselect/model-params.cc b/smartselect/model-params.cc
index 9a31bab..65c4f93 100644
--- a/smartselect/model-params.cc
+++ b/smartselect/model-params.cc
@@ -43,6 +43,7 @@
reader.trimmed_proto().GetExtension(feature_processor_extension_id);
// If no tokenization codepoint config is present, tokenize on space.
+ // TODO(zilka): Remove the default config.
if (feature_processor_options.tokenization_codepoint_config_size() == 0) {
TokenizationCodepointRange* config;
// New line character.
@@ -67,42 +68,16 @@
if (reader.trimmed_proto().HasExtension(selection_options_extension_id)) {
selection_options =
reader.trimmed_proto().GetExtension(selection_options_extension_id);
- } else {
- // Default values when SelectionModelOptions is not present.
- for (const auto codepoint_pair : std::vector<std::pair<int, int>>(
- {{33, 35}, {37, 39}, {42, 42}, {44, 47},
- {58, 59}, {63, 64}, {91, 93}, {95, 95},
- {123, 123}, {125, 125}, {161, 161}, {171, 171},
- {183, 183}, {187, 187}, {191, 191}, {894, 894},
- {903, 903}, {1370, 1375}, {1417, 1418}, {1470, 1470},
- {1472, 1472}, {1475, 1475}, {1478, 1478}, {1523, 1524},
- {1548, 1549}, {1563, 1563}, {1566, 1567}, {1642, 1645},
- {1748, 1748}, {1792, 1805}, {2404, 2405}, {2416, 2416},
- {3572, 3572}, {3663, 3663}, {3674, 3675}, {3844, 3858},
- {3898, 3901}, {3973, 3973}, {4048, 4049}, {4170, 4175},
- {4347, 4347}, {4961, 4968}, {5741, 5742}, {5787, 5788},
- {5867, 5869}, {5941, 5942}, {6100, 6102}, {6104, 6106},
- {6144, 6154}, {6468, 6469}, {6622, 6623}, {6686, 6687},
- {8208, 8231}, {8240, 8259}, {8261, 8273}, {8275, 8286},
- {8317, 8318}, {8333, 8334}, {9001, 9002}, {9140, 9142},
- {10088, 10101}, {10181, 10182}, {10214, 10219}, {10627, 10648},
- {10712, 10715}, {10748, 10749}, {11513, 11516}, {11518, 11519},
- {11776, 11799}, {11804, 11805}, {12289, 12291}, {12296, 12305},
- {12308, 12319}, {12336, 12336}, {12349, 12349}, {12448, 12448},
- {12539, 12539}, {64830, 64831}, {65040, 65049}, {65072, 65106},
- {65108, 65121}, {65123, 65123}, {65128, 65128}, {65130, 65131},
- {65281, 65283}, {65285, 65290}, {65292, 65295}, {65306, 65307},
- {65311, 65312}, {65339, 65341}, {65343, 65343}, {65371, 65371},
- {65373, 65373}, {65375, 65381}, {65792, 65793}, {66463, 66463},
- {68176, 68184}})) {
- for (int i = codepoint_pair.first; i <= codepoint_pair.second; i++) {
- selection_options.add_punctuation_to_strip(i);
- }
- selection_options.set_strip_punctuation(true);
- selection_options.set_enforce_symmetry(true);
- selection_options.set_symmetry_context_size(
- feature_processor_options.context_size() * 2);
+
+ // For backward compatibility with the current models.
+ if (!feature_processor_options.ignored_span_boundary_codepoints_size()) {
+ *feature_processor_options.mutable_ignored_span_boundary_codepoints() =
+ selection_options.deprecated_punctuation_to_strip();
}
+ } else {
+ selection_options.set_enforce_symmetry(true);
+ selection_options.set_symmetry_context_size(
+ feature_processor_options.context_size() * 2);
}
SharingModelOptions sharing_options;
diff --git a/smartselect/model-parser.cc b/smartselect/model-parser.cc
new file mode 100644
index 0000000..0cf05e3
--- /dev/null
+++ b/smartselect/model-parser.cc
@@ -0,0 +1,91 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/model-parser.h"
+#include "util/base/endian.h"
+
+namespace libtextclassifier {
+namespace {
+
+// Small helper class for parsing the merged model format.
+// The merged model consists of interleaved <int32 data_size, char* data>
+// segments.
+class MergedModelParser {
+ public:
+ MergedModelParser(const void* addr, const int size)
+ : addr_(reinterpret_cast<const char*>(addr)), size_(size), pos_(addr_) {}
+
+ bool ReadBytesAndAdvance(int num_bytes, const char** result) {
+ const char* read_addr = pos_;
+ if (Advance(num_bytes)) {
+ *result = read_addr;
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ bool ReadInt32AndAdvance(int* result) {
+ const char* read_addr = pos_;
+ if (Advance(sizeof(int))) {
+ *result =
+ LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(read_addr));
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ bool IsDone() { return pos_ == addr_ + size_; }
+
+ private:
+ bool Advance(int num_bytes) {
+ pos_ += num_bytes;
+ return pos_ <= addr_ + size_;
+ }
+
+ const char* addr_;
+ const int size_;
+ const char* pos_;
+};
+
+} // namespace
+
+bool ParseMergedModel(const void* addr, const int size,
+ const char** selection_model, int* selection_model_length,
+ const char** sharing_model, int* sharing_model_length) {
+ MergedModelParser parser(addr, size);
+
+ if (!parser.ReadInt32AndAdvance(selection_model_length)) {
+ return false;
+ }
+
+ if (!parser.ReadBytesAndAdvance(*selection_model_length, selection_model)) {
+ return false;
+ }
+
+ if (!parser.ReadInt32AndAdvance(sharing_model_length)) {
+ return false;
+ }
+
+ if (!parser.ReadBytesAndAdvance(*sharing_model_length, sharing_model)) {
+ return false;
+ }
+
+ return parser.IsDone();
+}
+
+} // namespace libtextclassifier
diff --git a/tests/functions.cc b/smartselect/model-parser.h
similarity index 62%
copy from tests/functions.cc
copy to smartselect/model-parser.h
index 8ea5a8d..801262f 100644
--- a/tests/functions.cc
+++ b/smartselect/model-parser.h
@@ -14,16 +14,16 @@
* limitations under the License.
*/
-#include "tests/functions.h"
-
-#include "common/registry.h"
+#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARSER_H_
+#define LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARSER_H_
namespace libtextclassifier {
-namespace nlp_core {
-TC_DEFINE_CLASS_REGISTRY_NAME("function", functions::Function);
+// Parse a merged model image.
+bool ParseMergedModel(const void* addr, const int size,
+ const char** selection_model, int* selection_model_length,
+ const char** sharing_model, int* sharing_model_length);
-TC_DEFINE_CLASS_REGISTRY_NAME("int-function", functions::IntFunction);
-
-} // namespace nlp_core
} // namespace libtextclassifier
+
+#endif // LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARSER_H_
diff --git a/smartselect/text-classification-model.cc b/smartselect/text-classification-model.cc
index dee8f8b..e7ae09c 100644
--- a/smartselect/text-classification-model.cc
+++ b/smartselect/text-classification-model.cc
@@ -16,6 +16,7 @@
#include "smartselect/text-classification-model.h"
+#include <cctype>
#include <cmath>
#include <iterator>
#include <numeric>
@@ -26,10 +27,14 @@
#include "common/memory_image/memory-image-reader.h"
#include "common/mmap.h"
#include "common/softmax.h"
+#include "smartselect/model-parser.h"
#include "smartselect/text-classification-model.pb.h"
#include "util/base/logging.h"
#include "util/utf8/unicodetext.h"
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+#include "unicode/regex.h"
#include "unicode/uchar.h"
+#endif
namespace libtextclassifier {
@@ -49,65 +54,71 @@
const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
if (i >= selection_indices.first && i < selection_indices.second &&
- u_isdigit(*it)) {
+ isdigit(*it)) {
++count;
}
}
return count;
}
-} // namespace
-
-CodepointSpan TextClassificationModel::StripPunctuation(
- CodepointSpan selection, const std::string& context) const {
- UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false);
- int context_length =
- std::distance(context_unicode.begin(), context_unicode.end());
-
- // Check that the indices are valid.
- if (selection.first < 0 || selection.first > context_length ||
- selection.second < 0 || selection.second > context_length) {
- return selection;
- }
-
- // Move the left border until we encounter a non-punctuation character.
- UnicodeText::const_iterator it_from_begin = context_unicode.begin();
- std::advance(it_from_begin, selection.first);
- for (; punctuation_to_strip_.find(*it_from_begin) !=
- punctuation_to_strip_.end();
- ++it_from_begin, ++selection.first) {
- }
-
- // Unless we are already at the end, move the right border until we encounter
- // a non-punctuation character.
- UnicodeText::const_iterator it_from_end = context_unicode.begin();
- std::advance(it_from_end, selection.second);
- if (it_from_begin != it_from_end) {
- --it_from_end;
- for (; punctuation_to_strip_.find(*it_from_end) !=
- punctuation_to_strip_.end();
- --it_from_end, --selection.second) {
- }
- return selection;
- } else {
- // When the token is all punctuation.
- return {0, 0};
- }
+std::string ExtractSelection(const std::string& context,
+ CodepointSpan selection_indices) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ auto selection_begin = context_unicode.begin();
+ std::advance(selection_begin, selection_indices.first);
+ auto selection_end = context_unicode.begin();
+ std::advance(selection_end, selection_indices.second);
+ return UnicodeText::UTF8Substring(selection_begin, selection_end);
}
-TextClassificationModel::TextClassificationModel(int fd) : mmap_(fd) {
- initialized_ = LoadModels(mmap_.handle());
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+bool MatchesRegex(const icu::RegexPattern* regex, const std::string& context) {
+ const icu::UnicodeString unicode_context(context.c_str(), context.size(),
+ "utf-8");
+ UErrorCode status = U_ZERO_ERROR;
+ std::unique_ptr<icu::RegexMatcher> matcher(
+ regex->matcher(unicode_context, status));
+ return matcher->matches(0 /* start */, status);
+}
+#endif
+
+} // namespace
+
+TextClassificationModel::TextClassificationModel(const std::string& path)
+ : mmap_(new nlp_core::ScopedMmap(path)) {
+ InitFromMmap();
+}
+
+TextClassificationModel::TextClassificationModel(int fd)
+ : mmap_(new nlp_core::ScopedMmap(fd)) {
+ InitFromMmap();
+}
+
+TextClassificationModel::TextClassificationModel(int fd, int offset, int size)
+ : mmap_(new nlp_core::ScopedMmap(fd, offset, size)) {
+ InitFromMmap();
+}
+
+TextClassificationModel::TextClassificationModel(const void* addr, int size) {
+ initialized_ = LoadModels(addr, size);
if (!initialized_) {
TC_LOG(ERROR) << "Failed to load models";
return;
}
+}
- selection_options_ = selection_params_->GetSelectionModelOptions();
- for (const int codepoint : selection_options_.punctuation_to_strip()) {
- punctuation_to_strip_.insert(codepoint);
+void TextClassificationModel::InitFromMmap() {
+ if (!mmap_->handle().ok()) {
+ return;
}
- sharing_options_ = selection_params_->GetSharingModelOptions();
+ initialized_ =
+ LoadModels(mmap_->handle().start(), mmap_->handle().num_bytes());
+ if (!initialized_) {
+ TC_LOG(ERROR) << "Failed to load models";
+ return;
+ }
}
namespace {
@@ -151,40 +162,48 @@
};
}
-void ParseMergedModel(const MmapHandle& mmap_handle,
- const char** selection_model, int* selection_model_length,
- const char** sharing_model, int* sharing_model_length) {
- // Read the length of the selection model.
- const char* model_data = reinterpret_cast<const char*>(mmap_handle.start());
- *selection_model_length =
- LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
- model_data += sizeof(*selection_model_length);
- *selection_model = model_data;
- model_data += *selection_model_length;
-
- *sharing_model_length =
- LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
- model_data += sizeof(*sharing_model_length);
- *sharing_model = model_data;
-}
-
} // namespace
-bool TextClassificationModel::LoadModels(const MmapHandle& mmap_handle) {
- if (!mmap_handle.ok()) {
- return false;
+void TextClassificationModel::InitializeSharingRegexPatterns(
+    const std::vector<SharingModelOptions::RegexPattern>& patterns) {
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+  // Compile each pattern; ones that fail to compile are logged and skipped.
+  for (const auto& regex_pattern : patterns) {
+    UErrorCode status = U_ZERO_ERROR;
+    std::unique_ptr<icu::RegexPattern> compiled_pattern(
+        icu::RegexPattern::compile(
+            icu::UnicodeString(regex_pattern.pattern().c_str(),
+                               regex_pattern.pattern().size(), "utf-8"),
+            0 /* flags */, status));
+    if (U_FAILURE(status)) {
+      TC_LOG(WARNING) << "Failed to load pattern: " << regex_pattern.pattern();
+    } else {
+      regex_patterns_.push_back(
+          {regex_pattern.collection_name(), std::move(compiled_pattern)});
+    }
  }
+#else
+  if (!patterns.empty()) {
+    TC_LOG(WARNING) << "ICU not supported regexp matchers ignored.";
+  }
+#endif
+}
+bool TextClassificationModel::LoadModels(const void* addr, int size) {
const char *selection_model, *sharing_model;
int selection_model_length, sharing_model_length;
- ParseMergedModel(mmap_handle, &selection_model, &selection_model_length,
- &sharing_model, &sharing_model_length);
+ if (!ParseMergedModel(addr, size, &selection_model, &selection_model_length,
+ &sharing_model, &sharing_model_length)) {
+ TC_LOG(ERROR) << "Couldn't parse the model.";
+ return false;
+ }
selection_params_.reset(
ModelParamsBuilder(selection_model, selection_model_length, nullptr));
if (!selection_params_.get()) {
return false;
}
+ selection_options_ = selection_params_->GetSelectionModelOptions();
selection_network_.reset(new EmbeddingNetwork(selection_params_.get()));
selection_feature_processor_.reset(
new FeatureProcessor(selection_params_->GetFeatureProcessorOptions()));
@@ -197,12 +216,17 @@
  if (!sharing_params_.get()) {
    return false;
  }
+  // NOTE: read from selection_params_ (presumably holds all ModelOptions).
+  sharing_options_ = selection_params_->GetSharingModelOptions();
  sharing_network_.reset(new EmbeddingNetwork(sharing_params_.get()));
  sharing_feature_processor_.reset(
      new FeatureProcessor(sharing_params_->GetFeatureProcessorOptions()));
  sharing_feature_fn_ = CreateFeatureVectorFn(
      *sharing_network_, sharing_network_->EmbeddingSize(0));
+  InitializeSharingRegexPatterns(std::vector<SharingModelOptions::RegexPattern>(
+      sharing_options_.regex_pattern().begin(),
+      sharing_options_.regex_pattern().end()));
  return true;
}
@@ -215,8 +239,12 @@
const char *selection_model, *sharing_model;
int selection_model_length, sharing_model_length;
- ParseMergedModel(mmap.handle(), &selection_model, &selection_model_length,
- &sharing_model, &sharing_model_length);
+ if (!ParseMergedModel(mmap.handle().start(), mmap.handle().num_bytes(),
+ &selection_model, &selection_model_length,
+ &sharing_model, &sharing_model_length)) {
+ TC_LOG(ERROR) << "Couldn't parse merged model.";
+ return false;
+ }
MemoryImageReader<EmbeddingNetworkProto> reader(selection_model,
selection_model_length);
@@ -245,14 +273,14 @@
CreateFeatureVectorFn(network, embedding_size),
embedding_size + feature_processor.DenseFeaturesCount(), &tokens,
&click_pos, &cached_features)) {
- TC_LOG(ERROR) << "Could not extract features.";
+ TC_VLOG(1) << "Could not extract features.";
return {};
}
VectorSpan<float> features;
VectorSpan<Token> output_tokens;
if (!cached_features->Get(click_pos, &features, &output_tokens)) {
- TC_LOG(ERROR) << "Could not extract features.";
+ TC_VLOG(1) << "Could not extract features.";
return {};
}
@@ -269,6 +297,104 @@
return scores;
}
+namespace {
+
+// True if codepoint occurs in context within [span.first, span.second).
+bool IsCodepointInSpan(const char32 codepoint, const std::string& context,
+                       const CodepointSpan span) {
+  const UnicodeText context_unicode =
+      UTF8ToUnicodeText(context, /*do_copy=*/false);
+
+  auto begin_it = context_unicode.begin();
+  std::advance(begin_it, span.first);
+  auto end_it = context_unicode.begin();
+  std::advance(end_it, span.second);
+
+  return std::find(begin_it, end_it, codepoint) != end_it;
+}
+
+// Returns the codepoint at span.first, which must be a valid index in context.
+char32 FirstSpanCodepoint(const std::string& context,
+                          const CodepointSpan span) {
+  const UnicodeText context_unicode =
+      UTF8ToUnicodeText(context, /*do_copy=*/false);
+
+  auto it = context_unicode.begin();
+  std::advance(it, span.first);
+  return *it;
+}
+
+// Returns the codepoint at span.second - 1; requires a non-empty, valid span.
+char32 LastSpanCodepoint(const std::string& context, const CodepointSpan span) {
+  const UnicodeText context_unicode =
+      UTF8ToUnicodeText(context, /*do_copy=*/false);
+
+  auto it = context_unicode.begin();
+  std::advance(it, span.second - 1);
+  return *it;
+}
+
+}  // namespace
+
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+
+namespace {
+// Classify codepoints via the Unicode Bidi_Paired_Bracket_Type property.
+bool IsOpenBracket(const char32 codepoint) {
+  return u_getIntPropertyValue(codepoint, UCHAR_BIDI_PAIRED_BRACKET_TYPE) ==
+         U_BPT_OPEN;
+}
+
+bool IsClosingBracket(const char32 codepoint) {
+  return u_getIntPropertyValue(codepoint, UCHAR_BIDI_PAIRED_BRACKET_TYPE) ==
+         U_BPT_CLOSE;
+}
+
+}  // namespace
+
+// Strips the first codepoint if it is a closing bracket or an opening bracket
+// whose pair is absent from the span; symmetrically for the last codepoint.
+// Empty context, empty spans and spans with a negative start are unchanged.
+CodepointSpan StripUnpairedBrackets(const std::string& context,
+                                    CodepointSpan span) {
+  if (context.empty() || span.first < 0 || span.first >= span.second) {
+    return span;
+  }
+
+  const char32 begin_char = FirstSpanCodepoint(context, span);
+
+  const char32 paired_begin_char = u_getBidiPairedBracket(begin_char);
+  if (paired_begin_char != begin_char) {
+    if (!IsOpenBracket(begin_char) ||
+        !IsCodepointInSpan(paired_begin_char, context, span)) {
+      ++span.first;
+    }
+  }
+
+  if (span.first == span.second) {
+    return span;
+  }
+
+  const char32 end_char = LastSpanCodepoint(context, span);
+  const char32 paired_end_char = u_getBidiPairedBracket(end_char);
+  if (paired_end_char != end_char) {
+    if (!IsClosingBracket(end_char) ||
+        !IsCodepointInSpan(paired_end_char, context, span)) {
+      --span.second;
+    }
+  }
+
+  // Should not happen, but let's make sure.
+  if (span.first > span.second) {
+    TC_LOG(WARNING) << "Inverse indices result: " << span.first << ", "
+                    << span.second;
+    span.second = span.first;
+  }
+
+  return span;
+}
+#endif
+
CodepointSpan TextClassificationModel::SuggestSelection(
const std::string& context, CodepointSpan click_indices) const {
if (!initialized_) {
@@ -276,19 +402,15 @@
return click_indices;
}
- if (std::get<0>(click_indices) >= std::get<1>(click_indices)) {
- TC_LOG(ERROR) << "Trying to run SuggestSelection with invalid indices:"
- << std::get<0>(click_indices) << " "
- << std::get<1>(click_indices);
- return click_indices;
- }
+ const int context_codepoint_size =
+ UTF8ToUnicodeText(context, /*do_copy=*/false).size();
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
- const int context_length =
- std::distance(context_unicode.begin(), context_unicode.end());
- if (std::get<0>(click_indices) >= context_length ||
- std::get<1>(click_indices) > context_length) {
+ if (click_indices.first < 0 || click_indices.second < 0 ||
+ click_indices.first >= context_codepoint_size ||
+ click_indices.second > context_codepoint_size ||
+ click_indices.first >= click_indices.second) {
+ TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices: "
+ << click_indices.first << " " << click_indices.second;
return click_indices;
}
@@ -300,28 +422,42 @@
std::tie(result, score) = SuggestSelectionInternal(context, click_indices);
}
- if (selection_options_.strip_punctuation()) {
- result = StripPunctuation(result, context);
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+ if (selection_options_.strip_unpaired_brackets()) {
+ const CodepointSpan stripped_result =
+ StripUnpairedBrackets(context, result);
+ if (stripped_result.first != stripped_result.second) {
+ result = stripped_result;
+ }
}
+#endif
return result;
}
namespace {
-std::pair<CodepointSpan, float> BestSelectionSpan(
- CodepointSpan original_click_indices, const std::vector<float>& scores,
- const std::vector<CodepointSpan>& selection_label_spans) {
+int BestPrediction(const std::vector<float>& scores) {
if (!scores.empty()) {
const int prediction =
std::max_element(scores.begin(), scores.end()) - scores.begin();
+ return prediction;
+ } else {
+ return kInvalidLabel;
+ }
+}
+
+std::pair<CodepointSpan, float> BestSelectionSpan(
+ CodepointSpan original_click_indices, const std::vector<float>& scores,
+ const std::vector<CodepointSpan>& selection_label_spans) {
+ const int prediction = BestPrediction(scores);
+ if (prediction != kInvalidLabel) {
std::pair<CodepointIndex, CodepointIndex> selection =
selection_label_spans[prediction];
if (selection.first == kInvalidIndex || selection.second == kInvalidIndex) {
- TC_LOG(ERROR) << "Invalid indices predicted, returning input: "
- << prediction << " " << selection.first << " "
- << selection.second;
+ TC_VLOG(1) << "Invalid indices predicted, returning input: " << prediction
+ << " " << selection.first << " " << selection.second;
return {original_click_indices, -1.0};
}
@@ -367,86 +503,17 @@
CodepointSpan TextClassificationModel::SuggestSelectionSymmetrical(
const std::string& context, CodepointSpan click_indices) const {
const int symmetry_context_size = selection_options_.symmetry_context_size();
- std::vector<Token> tokens;
- std::unique_ptr<CachedFeatures> cached_features;
- int click_index;
- int embedding_size = selection_network_->EmbeddingSize(0);
- if (!selection_feature_processor_->ExtractFeatures(
- context, click_indices, /*relative_click_span=*/
- {symmetry_context_size, symmetry_context_size + 1},
- selection_feature_fn_,
- embedding_size + selection_feature_processor_->DenseFeaturesCount(),
- &tokens, &click_index, &cached_features)) {
- TC_LOG(ERROR) << "Couldn't ExtractFeatures.";
- return click_indices;
- }
-
- // Scan in the symmetry context for selection span proposals.
- std::vector<std::pair<CodepointSpan, float>> proposals;
-
- for (int i = -symmetry_context_size; i < symmetry_context_size + 1; ++i) {
- const int token_index = click_index + i;
- if (token_index >= 0 && token_index < tokens.size() &&
- !tokens[token_index].is_padding) {
- float score;
- VectorSpan<float> features;
- VectorSpan<Token> output_tokens;
-
- CodepointSpan span;
- if (cached_features->Get(token_index, &features, &output_tokens)) {
- std::vector<float> scores;
- selection_network_->ComputeLogits(features, &scores);
-
- std::vector<CodepointSpan> selection_label_spans;
- if (selection_feature_processor_->SelectionLabelSpans(
- output_tokens, &selection_label_spans)) {
- scores = nlp_core::ComputeSoftmax(scores);
- std::tie(span, score) =
- BestSelectionSpan(click_indices, scores, selection_label_spans);
- if (span.first != kInvalidIndex && span.second != kInvalidIndex &&
- score >= 0) {
- proposals.push_back({span, score});
- }
- }
- }
+  std::vector<CodepointSpan> chunks = Chunk(
+      context, click_indices, {symmetry_context_size, symmetry_context_size});
+  for (const CodepointSpan& chunk : chunks) {
+    // Chunks are position-sorted; return the first that overlaps the click.
+    if (!(click_indices.first >= chunk.second ||
+          click_indices.second <= chunk.first)) {
+      return chunk;
+    }
}
- // Sort selection span proposals by their respective probabilities.
- std::sort(
- proposals.begin(), proposals.end(),
- [](std::pair<CodepointSpan, float> a, std::pair<CodepointSpan, float> b) {
- return a.second > b.second;
- });
-
- // Go from the highest-scoring proposal and claim tokens. Tokens are marked as
- // claimed by the higher-scoring selection proposals, so that the
- // lower-scoring ones cannot use them. Returns the selection proposal if it
- // contains the clicked token.
- std::vector<int> used_tokens(tokens.size(), 0);
- for (auto span_result : proposals) {
- TokenSpan span = CodepointSpanToTokenSpan(tokens, span_result.first);
- if (span.first != kInvalidIndex && span.second != kInvalidIndex) {
- bool feasible = true;
- for (int i = span.first; i < span.second; i++) {
- if (used_tokens[i] != 0) {
- feasible = false;
- break;
- }
- }
-
- if (feasible) {
- if (span.first <= click_index && span.second > click_index) {
- return {span_result.first.first, span_result.first.second};
- }
- for (int i = span.first; i < span.second; i++) {
- used_tokens[i] = 1;
- }
- }
- }
- }
-
- return {click_indices.first, click_indices.second};
+ return click_indices;
}
std::vector<std::pair<std::string, float>>
@@ -459,9 +526,9 @@
}
if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
- TC_LOG(ERROR) << "Trying to run ClassifyText with invalid indices: "
- << std::get<0>(selection_indices) << " "
- << std::get<1>(selection_indices);
+ TC_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
+ << std::get<0>(selection_indices) << " "
+ << std::get<1>(selection_indices);
return {};
}
@@ -475,21 +542,31 @@
return {{kEmailHintCollection, 1.0}};
}
+ // Check whether any of the regular expressions match.
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+ const std::string selection_text =
+ ExtractSelection(context, selection_indices);
+ for (const CompiledRegexPattern& regex_pattern : regex_patterns_) {
+ if (MatchesRegex(regex_pattern.pattern.get(), selection_text)) {
+ return {{regex_pattern.collection_name, 1.0}};
+ }
+ }
+#endif
+
EmbeddingNetwork::Vector scores =
InferInternal(context, selection_indices, *sharing_feature_processor_,
*sharing_network_, sharing_feature_fn_, nullptr);
if (scores.empty() ||
scores.size() != sharing_feature_processor_->NumCollections()) {
- TC_LOG(ERROR) << "Using default class: scores.size() = " << scores.size();
+ TC_VLOG(1) << "Using default class: scores.size() = " << scores.size();
return {};
}
scores = nlp_core::ComputeSoftmax(scores);
- std::vector<std::pair<std::string, float>> result;
+ std::vector<std::pair<std::string, float>> result(scores.size());
for (int i = 0; i < scores.size(); i++) {
- result.push_back(
- {sharing_feature_processor_->LabelToCollection(i), scores[i]});
+ result[i] = {sharing_feature_processor_->LabelToCollection(i), scores[i]};
}
std::sort(result.begin(), result.end(),
[](const std::pair<std::string, float>& a,
@@ -509,4 +586,156 @@
return result;
}
+std::vector<CodepointSpan> TextClassificationModel::Chunk(
+    const std::string& context, CodepointSpan click_span,
+    TokenSpan relative_click_span) const {
+  std::unique_ptr<CachedFeatures> cached_features;
+  std::vector<Token> tokens;
+  int click_index;
+  int embedding_size = selection_network_->EmbeddingSize(0);
+  if (!selection_feature_processor_->ExtractFeatures(
+          context, click_span, relative_click_span, selection_feature_fn_,
+          embedding_size + selection_feature_processor_->DenseFeaturesCount(),
+          &tokens, &click_index, &cached_features)) {
+    TC_VLOG(1) << "Couldn't ExtractFeatures.";
+    return {};
+  }
+
+  int first_token;
+  int last_token;
+  if (relative_click_span.first == kInvalidIndex ||
+      relative_click_span.second == kInvalidIndex) {
+    first_token = 0;
+    last_token = tokens.size();
+  } else {
+    first_token = click_index - relative_click_span.first;
+    last_token = click_index + relative_click_span.second + 1;
+  }
+
+  struct SelectionProposal {
+    int label;
+    int token_index;
+    CodepointSpan span;
+    float score;
+  };
+
+  // Collect selection span proposals for tokens in [first_token, last_token).
+  std::vector<SelectionProposal> proposals;
+  for (int token_index = first_token; token_index < last_token; ++token_index) {
+    if (token_index < 0 || token_index >= tokens.size() ||
+        tokens[token_index].is_padding) {
+      continue;
+    }
+
+    float score;
+    VectorSpan<float> features;
+    VectorSpan<Token> output_tokens;
+    std::vector<CodepointSpan> selection_label_spans;
+    CodepointSpan span;
+    if (cached_features->Get(token_index, &features, &output_tokens) &&
+        selection_feature_processor_->SelectionLabelSpans(
+            output_tokens, &selection_label_spans) &&
+        !selection_label_spans.empty()) {
+      // Implicit proposal of the token by itself, so every token is covered.
+      proposals.push_back(
+          SelectionProposal{0, token_index, selection_label_spans[0], 0.0});
+
+      std::vector<float> scores;
+      selection_network_->ComputeLogits(features, &scores);
+
+      scores = nlp_core::ComputeSoftmax(scores);
+      std::tie(span, score) = BestSelectionSpan({kInvalidIndex, kInvalidIndex},
+                                                scores, selection_label_spans);
+      if (span.first != kInvalidIndex && span.second != kInvalidIndex &&
+          score >= 0) {
+        const int prediction = BestPrediction(scores);
+        proposals.push_back(
+            SelectionProposal{prediction, token_index, span, score});
+      }
+    } else {
+      // Add an implicit proposal for each token to be by itself. Every token
+      // should be now represented in the results.
+      proposals.push_back(SelectionProposal{
+          0,
+          token_index,
+          {tokens[token_index].start, tokens[token_index].end},
+          0.0});
+    }
+  }
+
+  // Sort selection span proposals by their respective probabilities.
+  std::sort(proposals.begin(), proposals.end(),
+            [](const SelectionProposal& a, const SelectionProposal& b) {
+              return a.score > b.score;
+            });
+
+  // Greedily accept proposals from the highest score down. Tokens claimed by a
+  // higher-scoring proposal cannot be reused, so the chunks never share a
+  // token. Every proposal that is still feasible becomes a chunk.
+  std::vector<CodepointSpan> result;
+  std::vector<bool> token_used(tokens.size(), false);
+  for (const SelectionProposal& proposal : proposals) {
+    const int predicted_label = proposal.label;
+    TokenSpan relative_span;
+    if (!selection_feature_processor_->LabelToTokenSpan(predicted_label,
+                                                        &relative_span)) {
+      continue;
+    }
+    TokenSpan span;
+    span.first = proposal.token_index - relative_span.first;
+    span.second = proposal.token_index + relative_span.second + 1;
+
+    // Proposals reaching past either end of `tokens` would index out of range.
+    if (span.first >= 0 && span.second <= static_cast<int>(tokens.size())) {
+      bool feasible = true;
+      for (int i = span.first; i < span.second; i++) {
+        if (token_used[i]) {
+          feasible = false;
+          break;
+        }
+      }
+
+      if (feasible) {
+        result.push_back(proposal.span);
+        for (int i = span.first; i < span.second; i++) {
+          token_used[i] = true;
+        }
+      }
+    }
+  }
+
+  std::sort(result.begin(), result.end(),
+            [](const CodepointSpan& a, const CodepointSpan& b) {
+              return a.first < b.first;
+            });
+
+  return result;
+}
+
+std::vector<TextClassificationModel::AnnotatedSpan>
+TextClassificationModel::Annotate(const std::string& context) const {
+  if (!initialized_) return {};  // Feature processors may be null.
+  std::vector<CodepointSpan> chunks;
+  const UnicodeText context_unicode = UTF8ToUnicodeText(context,
+                                                        /*do_copy=*/false);
+  for (const UnicodeTextRange& line :
+       selection_feature_processor_->SplitContext(context_unicode)) {
+    const std::vector<CodepointSpan> local_chunks =
+        Chunk(UnicodeText::UTF8Substring(line.first, line.second),
+              /*click_span=*/{kInvalidIndex, kInvalidIndex},
+              /*relative_click_span=*/{kInvalidIndex, kInvalidIndex});
+    const int offset = std::distance(context_unicode.begin(), line.first);
+    for (CodepointSpan chunk : local_chunks) {
+      chunks.push_back({chunk.first + offset, chunk.second + offset});
+    }
+  }
+  std::vector<TextClassificationModel::AnnotatedSpan> result;
+  for (const CodepointSpan& chunk : chunks) {
+    result.emplace_back();
+    result.back().span = chunk;
+    result.back().classification = ClassifyText(context, chunk);
+  }
+  return result;
+}
+
} // namespace libtextclassifier
diff --git a/smartselect/text-classification-model.h b/smartselect/text-classification-model.h
index 522372c..d0df193 100644
--- a/smartselect/text-classification-model.h
+++ b/smartselect/text-classification-model.h
@@ -23,7 +23,6 @@
#include <set>
#include <string>
-#include "base.h"
#include "common/embedding-network.h"
#include "common/feature-extractor.h"
#include "common/memory_image/embedding-network-params-from-image.h"
@@ -38,10 +37,33 @@
// SmartSelection/Sharing feed-forward model.
class TextClassificationModel {
public:
+  // One result of an Annotate() call: a span and its classification.
+  struct AnnotatedSpan {
+    // Unicode codepoint indices in the input string.
+    CodepointSpan span = {kInvalidIndex, kInvalidIndex};
+
+    // (collection, score) pairs from ClassifyText() for this span.
+    std::vector<std::pair<std::string, float>> classification;
+  };
+
  // Loads TextClassificationModel from given file given by an int
  // file descriptor.
+  // Offset is the byte position in the file where the model data begins.
+  TextClassificationModel(int fd, int offset, int size);
+
+  // Same as above but the whole file is mapped and it is assumed the model
+  // starts at offset 0.
  explicit TextClassificationModel(int fd);
+  // Loads TextClassificationModel from given file.
+  explicit TextClassificationModel(const std::string& path);
+
+  // Loads TextClassificationModel from given memory; no mmap is created.
+  TextClassificationModel(const void* addr, int size);
+
+  // Returns true if the model is ready for use.
+  bool IsInitialized() const { return initialized_; }
+
// Bit flags for the input selection.
enum SelectionInputFlags { SELECTION_IS_URL = 0x1, SELECTION_IS_EMAIL = 0x2 };
@@ -63,17 +85,33 @@
const std::string& context, CodepointSpan click_indices,
int input_flags = 0) const;
+  // Annotates the input text. The annotations are meant to cover the whole
+  // context except for whitespace, and are sorted by their position in the
+  // context string.
+  std::vector<AnnotatedSpan> Annotate(const std::string& context) const;
+
protected:
- // Removes punctuation from the beginning and end of the selection and returns
- // the new selection span.
- CodepointSpan StripPunctuation(CodepointSpan selection,
- const std::string& context) const;
+  // Loads the models from the memory mapped by mmap_ and sets initialized_.
+  void InitFromMmap();
+
+  // Extracts non-overlapping chunks from the context. The extraction proceeds
+  // from the center token determined by click_span and looks at
+  // relative_click_span tokens left and right around the click position.
+  // If either element of relative_click_span is kInvalidIndex, the whole
+  // context is considered, regardless of click_span.
+  // Returns the chunks sorted by their position in the context string.
+  std::vector<CodepointSpan> Chunk(const std::string& context,
+                                   CodepointSpan click_span,
+                                   TokenSpan relative_click_span) const;
// During evaluation we need access to the feature processor.
FeatureProcessor* SelectionFeatureProcessor() const {
return selection_feature_processor_.get();
}
+ void InitializeSharingRegexPatterns(
+ const std::vector<SharingModelOptions::RegexPattern>& patterns);
+
// Collection name when url hint is accepted.
const std::string kUrlHintCollection = "url";
@@ -90,7 +128,14 @@
SharingModelOptions sharing_options_;
private:
- bool LoadModels(const nlp_core::MmapHandle& mmap_handle);
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+  struct CompiledRegexPattern {
+    std::string collection_name;
+    std::unique_ptr<icu::RegexPattern> pattern;
+  };
+#endif
+  // Parses the merged model at addr (size bytes); returns false on failure.
+  bool LoadModels(const void* addr, int size);
nlp_core::EmbeddingNetwork::Vector InferInternal(
const std::string& context, CodepointSpan span,
@@ -108,8 +153,8 @@
CodepointSpan SuggestSelectionSymmetrical(const std::string& context,
CodepointSpan click_indices) const;
- bool initialized_;
- nlp_core::ScopedMmap mmap_;
+ bool initialized_ = false;
+ std::unique_ptr<nlp_core::ScopedMmap> mmap_;
std::unique_ptr<ModelParams> selection_params_;
std::unique_ptr<FeatureProcessor> selection_feature_processor_;
std::unique_ptr<nlp_core::EmbeddingNetwork> selection_network_;
@@ -118,14 +163,34 @@
std::unique_ptr<ModelParams> sharing_params_;
std::unique_ptr<nlp_core::EmbeddingNetwork> sharing_network_;
FeatureVectorFn sharing_feature_fn_;
-
- std::set<int> punctuation_to_strip_;
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+ std::vector<CompiledRegexPattern> regex_patterns_;
+#endif
};
+// Strips the first codepoint if it is a closing bracket or an opening bracket
+// whose pair is absent from the span; symmetrically for the last codepoint.
+// Only defined when built with ICU support.
+CodepointSpan StripUnpairedBrackets(const std::string& context,
+                                    CodepointSpan span);
+
// Parses the merged image given as a file descriptor, and reads
// the ModelOptions proto from the selection model.
bool ReadSelectionModelOptions(int fd, ModelOptions* model_options);
+// Prints "Span(begin, end, class, score)" from the first classification entry.
+inline std::ostream& operator<<(
+    std::ostream& os, const TextClassificationModel::AnnotatedSpan& span) {
+  std::string best_class;
+  float best_score = -1;
+  if (!span.classification.empty()) {
+    best_class = span.classification[0].first;
+    best_score = span.classification[0].second;
+  }
+  return os << "Span(" << span.span.first << ", " << span.span.second << ", "
+            << best_class << ", " << best_score << ")";
+}
+
} // namespace libtextclassifier
#endif // LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto
index b5b0287..315e849 100644
--- a/smartselect/text-classification-model.proto
+++ b/smartselect/text-classification-model.proto
@@ -37,10 +37,7 @@
message SelectionModelOptions {
// A list of Unicode codepoints to strip from predicted selections.
- repeated int32 punctuation_to_strip = 1;
-
- // Whether to strip punctuation after the selection is made.
- optional bool strip_punctuation = 2;
+ repeated int32 deprecated_punctuation_to_strip = 1;
// Enforce symmetrical selections.
optional bool enforce_symmetry = 3;
@@ -48,6 +45,14 @@
// Number of inferences made around the click position (to one side), for
// enforcing symmetry.
optional int32 symmetry_context_size = 4;
+
+  // If true, unpaired brackets at either end of the predicted selection are
+  // stripped before the selection is returned. The bracket codepoints are
+  // defined in the Unicode standard:
+  // http://www.unicode.org/Public/UNIDATA/BidiBrackets.txt
+  optional bool strip_unpaired_brackets = 5 [default = true];
+
+ reserved 2;
}
message SharingModelOptions {
@@ -60,8 +65,19 @@
// Limits for phone numbers.
optional int32 phone_min_num_digits = 3 [default = 7];
optional int32 phone_max_num_digits = 4 [default = 15];
+
+  // Regex matchers tried in order by ClassifyText before running the model.
+  message RegexPattern {
+    // Collection name returned by ClassifyText for a matching selection.
+    optional string collection_name = 1;
+
+    // ICU regular expression; it must match the entire selection text.
+    optional string pattern = 2;
+  }
+  repeated RegexPattern regex_pattern = 5;
}
+// Next ID: 41
message FeatureProcessorOptions {
// Number of buckets used for hashing charactergrams.
optional int32 num_buckets = 1 [default = -1];
@@ -193,7 +209,18 @@
[default = INTERNAL_TOKENIZER];
optional bool icu_preserve_whitespace_tokens = 31 [default = false];
- reserved 7, 11, 12, 17, 26, 27, 28, 29, 32;
+ // List of codepoints that will be stripped from beginning and end of
+ // predicted spans.
+ repeated int32 ignored_span_boundary_codepoints = 36;
+
+ reserved 7, 11, 12, 26, 27, 28, 29, 32, 35, 39, 40;
+
+ // List of allowed charactergrams. The extracted charactergrams are filtered
+ // using this list, and charactergrams that are not present are interpreted as
+ // out-of-vocabulary.
+ // If no allowed_chargrams are specified, all charactergrams are allowed.
+ // The field is typed as bytes type to allow non-UTF8 chargrams.
+ repeated bytes allowed_chargrams = 38;
};
extend nlp_core.EmbeddingNetworkProto {
diff --git a/tests/text-classification-model_test.cc b/smartselect/text-classification-model_test.cc
similarity index 70%
rename from tests/text-classification-model_test.cc
rename to smartselect/text-classification-model_test.cc
index ed00876..5550e53 100644
--- a/tests/text-classification-model_test.cc
+++ b/smartselect/text-classification-model_test.cc
@@ -18,19 +18,41 @@
#include <fcntl.h>
#include <stdio.h>
+#include <fstream>
+#include <iostream>
#include <memory>
#include <string>
-#include "base.h"
#include "gtest/gtest.h"
namespace libtextclassifier {
namespace {
+std::string ReadFile(const std::string& file_name) {
+  std::ifstream file_stream(file_name);
+  return std::string(std::istreambuf_iterator<char>(file_stream), {});
+}
+
std::string GetModelPath() {
  return TEST_DATA_DIR "smartselection.model";
}
+std::string GetURLRegexPath() {
+  return TEST_DATA_DIR "regex_url.txt";
+}
+
+std::string GetEmailRegexPath() {
+  return TEST_DATA_DIR "regex_email.txt";
+}
+
+TEST(TextClassificationModelTest, StripUnpairedBrackets) {
+  // A selection made of a single unpaired bracket is stripped to empty.
+  EXPECT_EQ(StripUnpairedBrackets("call me at ) today", {11, 12}),
+            std::make_pair(12, 12));
+  EXPECT_EQ(StripUnpairedBrackets("call me at ( today", {11, 12}),
+            std::make_pair(12, 12));
+}
+
TEST(TextClassificationModelTest, ReadModelOptions) {
const std::string model_path = GetModelPath();
int fd = open(model_path.c_str(), O_RDONLY);
@@ -63,6 +85,29 @@
// Single word.
EXPECT_EQ(std::make_pair(0, 4), model->SuggestSelection("asdf", {0, 4}));
+
+ EXPECT_EQ(model->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+ std::make_pair(11, 23));
+
+ // Unpaired bracket stripping.
+ EXPECT_EQ(
+ model->SuggestSelection("call me at (857) 225 3556 today", {11, 16}),
+ std::make_pair(11, 25));
+ EXPECT_EQ(model->SuggestSelection("call me at (857 225 3556 today", {11, 15}),
+ std::make_pair(12, 24));
+ EXPECT_EQ(model->SuggestSelection("call me at 857 225 3556) today", {11, 14}),
+ std::make_pair(11, 23));
+ EXPECT_EQ(
+ model->SuggestSelection("call me at )857 225 3556( today", {11, 15}),
+ std::make_pair(12, 24));
+
+ // If the resulting selection would be empty, the original span is returned.
+ EXPECT_EQ(model->SuggestSelection("call me at )( today", {11, 13}),
+ std::make_pair(11, 13));
+ EXPECT_EQ(model->SuggestSelection("call me at ( today", {11, 12}),
+ std::make_pair(11, 12));
+ EXPECT_EQ(model->SuggestSelection("call me at ) today", {11, 12}),
+ std::make_pair(11, 12));
}
TEST(TextClassificationModelTest, SuggestSelectionsAreSymmetric) {
@@ -140,7 +185,7 @@
explicit TestingTextClassificationModel(int fd)
: libtextclassifier::TextClassificationModel(fd) {}
- using libtextclassifier::TextClassificationModel::StripPunctuation;
+ using TextClassificationModel::InitializeSharingRegexPatterns;
void DisableClassificationHints() {
sharing_options_.set_always_accept_url_hint(false);
@@ -148,29 +193,6 @@
}
};
-TEST(TextClassificationModelTest, StripPunctuation) {
- const std::string model_path = GetModelPath();
- int fd = open(model_path.c_str(), O_RDONLY);
- std::unique_ptr<TestingTextClassificationModel> model(
- new TestingTextClassificationModel(fd));
- close(fd);
-
- EXPECT_EQ(std::make_pair(3, 10),
- model->StripPunctuation({0, 10}, ".,-abcd.()"));
- EXPECT_EQ(std::make_pair(0, 6), model->StripPunctuation({0, 6}, "(abcd)"));
- EXPECT_EQ(std::make_pair(1, 5), model->StripPunctuation({0, 6}, "[abcd]"));
- EXPECT_EQ(std::make_pair(1, 5), model->StripPunctuation({0, 6}, "{abcd}"));
-
- // Empty result.
- EXPECT_EQ(std::make_pair(0, 0), model->StripPunctuation({0, 1}, "&"));
- EXPECT_EQ(std::make_pair(0, 0), model->StripPunctuation({0, 4}, "&-,}"));
-
- // Invalid indices
- EXPECT_EQ(std::make_pair(-1, 523), model->StripPunctuation({-1, 523}, "a"));
- EXPECT_EQ(std::make_pair(-1, -1), model->StripPunctuation({-1, -1}, "a"));
- EXPECT_EQ(std::make_pair(0, -1), model->StripPunctuation({0, -1}, "a"));
-}
-
TEST(TextClassificationModelTest, SuggestSelectionNoCrashWithJunk) {
const std::string model_path = GetModelPath();
int fd = open(model_path.c_str(), O_RDONLY);
@@ -328,5 +350,91 @@
"phone: (123) 456 789,0001112", {7, 28}, 0)));
}
+// End-to-end check of TextClassificationModel::Annotate: the text is split into
+// the expected spans, and spans that list a classification must have that
+// collection as their top-scoring result.
+TEST(TextClassificationModelTest, Annotate) {
+  const std::string model_path = GetModelPath();
+  int fd = open(model_path.c_str(), O_RDONLY);
+  std::unique_ptr<TestingTextClassificationModel> model(
+      new TestingTextClassificationModel(fd));
+  close(fd);
+
+  std::string test_string =
+      "& saw Barak Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225-3556.";
+  std::vector<TextClassificationModel::AnnotatedSpan> result =
+      model->Annotate(test_string);
+
+  // Expected spans, in order. Entries without a classification are only
+  // checked for their boundaries.
+  std::vector<TextClassificationModel::AnnotatedSpan> expected;
+  expected.emplace_back();
+  expected.back().span = {0, 0};
+  expected.emplace_back();
+  expected.back().span = {2, 5};
+  expected.back().classification.push_back({"other", 1.0});
+  expected.emplace_back();
+  expected.back().span = {6, 17};
+  expected.back().classification.push_back({"other", 1.0});
+  expected.emplace_back();
+  expected.back().span = {18, 23};
+  expected.back().classification.push_back({"other", 1.0});
+  expected.emplace_back();
+  expected.back().span = {24, 24};
+  expected.emplace_back();
+  expected.back().span = {27, 54};
+  expected.back().classification.push_back({"address", 1.0});
+  expected.emplace_back();
+  expected.back().span = {55, 58};
+  expected.back().classification.push_back({"other", 1.0});
+  expected.emplace_back();
+  expected.back().span = {59, 61};
+  expected.back().classification.push_back({"other", 1.0});
+  expected.emplace_back();
+  expected.back().span = {62, 74};
+  expected.back().classification.push_back({"other", 1.0});
+  expected.emplace_back();
+  expected.back().span = {75, 77};
+  expected.back().classification.push_back({"other", 1.0});
+  expected.emplace_back();
+  expected.back().span = {78, 90};
+  expected.back().classification.push_back({"phone", 1.0});
+
+  // ASSERT (fatal) rather than EXPECT: on a size mismatch the loop below would
+  // otherwise index `result` out of bounds.
+  ASSERT_EQ(result.size(), expected.size());
+  for (size_t i = 0; i < expected.size(); ++i) {
+    EXPECT_EQ(result[i].span, expected[i].span) << result[i];
+    if (!expected[i].classification.empty()) {
+      // Fatal check guards the classification[0] access just below.
+      ASSERT_FALSE(result[i].classification.empty()) << result[i];
+      EXPECT_EQ(result[i].classification[0].first,
+                expected[i].classification[0].first)
+          << result[i];
+    }
+  }
+}
+
+// Installs "email" and "url" sharing regex patterns (read from their pattern
+// files) into the model, then checks that ClassifyText picks the matching
+// collection for example selections.
+TEST(TextClassificationModelTest, URLEmailRegex) {
+  const std::string model_path = GetModelPath();
+  int fd = open(model_path.c_str(), O_RDONLY);
+  std::unique_ptr<TestingTextClassificationModel> model(
+      new TestingTextClassificationModel(fd));
+  close(fd);
+
+  SharingModelOptions options;
+  SharingModelOptions::RegexPattern* email_regex = options.add_regex_pattern();
+  email_regex->set_collection_name("email");
+  email_regex->set_pattern(ReadFile(GetEmailRegexPath()));
+  SharingModelOptions::RegexPattern* url_regex = options.add_regex_pattern();
+  url_regex->set_collection_name("url");
+  url_regex->set_pattern(ReadFile(GetURLRegexPath()));
+
+  // TODO(b/69538802): Modify directly the model image instead.
+  const auto& patterns = options.regex_pattern();
+  model->InitializeSharingRegexPatterns({patterns.begin(), patterns.end()});
+
+  // Name of the best-scoring collection for `span` within `context`.
+  const auto best_collection = [&model](const std::string& context,
+                                        const std::pair<int, int>& span) {
+    return FindBestResult(model->ClassifyText(context, span));
+  };
+  EXPECT_EQ("url",
+            best_collection("Visit www.google.com every today!", {6, 20}));
+  EXPECT_EQ("email",
+            best_collection("My email: asdf@something.cz", {10, 27}));
+  EXPECT_EQ("url",
+            best_collection("Login: http://asdf@something.cz", {7, 31}));
+}
+
} // namespace
} // namespace libtextclassifier
diff --git a/smartselect/token-feature-extractor.cc b/smartselect/token-feature-extractor.cc
index 479be41..6afd951 100644
--- a/smartselect/token-feature-extractor.cc
+++ b/smartselect/token-feature-extractor.cc
@@ -16,14 +16,17 @@
#include "smartselect/token-feature-extractor.h"
+#include <cctype>
#include <string>
#include "util/base/logging.h"
#include "util/hash/farmhash.h"
#include "util/strings/stringpiece.h"
#include "util/utf8/unicodetext.h"
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
#include "unicode/regex.h"
#include "unicode/uchar.h"
+#endif
namespace libtextclassifier {
@@ -47,6 +50,7 @@
return copy;
}
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
void RemapTokenUnicode(const std::string& token,
const TokenFeatureExtractorOptions& options,
UnicodeText* remapped) {
@@ -70,12 +74,14 @@
icu_string.toUTF8String(utf8_str);
remapped->CopyUTF8(utf8_str.data(), utf8_str.length());
}
+#endif
} // namespace
TokenFeatureExtractor::TokenFeatureExtractor(
const TokenFeatureExtractorOptions& options)
: options_(options) {
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
UErrorCode status;
for (const std::string& pattern : options.regexp_features) {
status = U_ZERO_ERROR;
@@ -87,10 +93,44 @@
TC_LOG(WARNING) << "Failed to load pattern" << pattern;
}
}
+#else
+ bool found_unsupported_regexp_features = false;
+ for (const std::string& pattern : options.regexp_features) {
+ // A temporary solution to support this specific regexp pattern without
+ // adding too much binary size.
+ if (pattern == "^[^a-z]*$") {
+ enable_all_caps_feature_ = true;
+ } else {
+ found_unsupported_regexp_features = true;
+ }
+ }
+ if (found_unsupported_regexp_features) {
+ TC_LOG(WARNING) << "ICU not supported regexp features ignored.";
+ }
+#endif
}
+// Maps `token` to a feature bucket (assumes options_.num_buckets > 0).
+// When allowed_chargrams is non-empty, bucket 0 is reserved for
+// out-of-vocabulary tokens and bucket 1 for "<PAD>", and allowed chargrams
+// are hashed into the remaining buckets.
int TokenFeatureExtractor::HashToken(StringPiece token) const {
- return tcfarmhash::Fingerprint64(token) % options_.num_buckets;
+  if (options_.allowed_chargrams.empty()) {
+    return tcfarmhash::Fingerprint64(token) % options_.num_buckets;
+  }
+
+  // Padding and out-of-vocabulary tokens have extra buckets reserved because
+  // they are special and important tokens, and we don't want them to share
+  // embedding with other charactergrams.
+  // TODO(zilka): Experimentally verify.
+  const int kNumExtraBuckets = 2;
+  const int kOutOfVocabularyBucket = 0;
+  const int kPaddingBucket = 1;
+  const std::string token_string = token.ToString();
+  if (token_string == "<PAD>") {
+    return kPaddingBucket;
+  }
+  if (options_.allowed_chargrams.find(token_string) ==
+          options_.allowed_chargrams.end() ||
+      options_.num_buckets <= kNumExtraBuckets) {
+    // Out-of-vocabulary. This bucket is also used when no regular buckets
+    // remain after the reserved ones; otherwise the modulo below would divide
+    // by zero (undefined behavior) or by a negative count.
+    return kOutOfVocabularyBucket;
+  }
+  return (tcfarmhash::Fingerprint64(token) %
+          (options_.num_buckets - kNumExtraBuckets)) +
+         kNumExtraBuckets;
}
std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
@@ -126,19 +166,23 @@
// Upper-bound the number of charactergram extracted to avoid resizing.
result.reserve(options_.chargram_orders.size() * feature_word.size());
- // Generate the character-grams.
- for (int chargram_order : options_.chargram_orders) {
- if (chargram_order == 1) {
- for (int i = 1; i < feature_word.size() - 1; ++i) {
- result.push_back(
- HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
- }
- } else {
- for (int i = 0;
- i < static_cast<int>(feature_word.size()) - chargram_order + 1;
- ++i) {
- result.push_back(HashToken(
- StringPiece(feature_word, /*offset=*/i, /*len=*/chargram_order)));
+ if (options_.chargram_orders.empty()) {
+ result.push_back(HashToken(feature_word));
+ } else {
+ // Generate the character-grams.
+ for (int chargram_order : options_.chargram_orders) {
+ if (chargram_order == 1) {
+ for (int i = 1; i < feature_word.size() - 1; ++i) {
+ result.push_back(
+ HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
+ }
+ } else {
+ for (int i = 0;
+ i < static_cast<int>(feature_word.size()) - chargram_order + 1;
+ ++i) {
+ result.push_back(HashToken(StringPiece(feature_word, /*offset=*/i,
+ /*len=*/chargram_order)));
+ }
}
}
}
@@ -148,6 +192,7 @@
std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode(
const Token& token) const {
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
std::vector<int> result;
if (token.is_padding || token.value.empty()) {
result.push_back(HashToken("<PAD>"));
@@ -186,39 +231,47 @@
// Upper-bound the number of charactergram extracted to avoid resizing.
result.reserve(options_.chargram_orders.size() * feature_word.size());
- // Generate the character-grams.
- for (int chargram_order : options_.chargram_orders) {
- UnicodeText::const_iterator it_start = feature_word_unicode.begin();
- UnicodeText::const_iterator it_end = feature_word_unicode.end();
- if (chargram_order == 1) {
- ++it_start;
- --it_end;
- }
-
- UnicodeText::const_iterator it_chargram_start = it_start;
- UnicodeText::const_iterator it_chargram_end = it_start;
- bool chargram_is_complete = true;
- for (int i = 0; i < chargram_order; ++i) {
- if (it_chargram_end == it_end) {
- chargram_is_complete = false;
- break;
+ if (options_.chargram_orders.empty()) {
+ result.push_back(HashToken(feature_word));
+ } else {
+ // Generate the character-grams.
+ for (int chargram_order : options_.chargram_orders) {
+ UnicodeText::const_iterator it_start = feature_word_unicode.begin();
+ UnicodeText::const_iterator it_end = feature_word_unicode.end();
+ if (chargram_order == 1) {
+ ++it_start;
+ --it_end;
}
- ++it_chargram_end;
- }
- if (!chargram_is_complete) {
- continue;
- }
- for (; it_chargram_end <= it_end;
- ++it_chargram_start, ++it_chargram_end) {
- const int length_bytes =
- it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
- result.push_back(HashToken(
- StringPiece(it_chargram_start.utf8_data(), length_bytes)));
+ UnicodeText::const_iterator it_chargram_start = it_start;
+ UnicodeText::const_iterator it_chargram_end = it_start;
+ bool chargram_is_complete = true;
+ for (int i = 0; i < chargram_order; ++i) {
+ if (it_chargram_end == it_end) {
+ chargram_is_complete = false;
+ break;
+ }
+ ++it_chargram_end;
+ }
+ if (!chargram_is_complete) {
+ continue;
+ }
+
+ for (; it_chargram_end <= it_end;
+ ++it_chargram_start, ++it_chargram_end) {
+ const int length_bytes =
+ it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
+ result.push_back(HashToken(
+ StringPiece(it_chargram_start.utf8_data(), length_bytes)));
+ }
}
}
}
return result;
+#else
+ TC_LOG(WARNING) << "ICU not supported. No feature extracted.";
+ return {};
+#endif
}
bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
@@ -234,7 +287,14 @@
if (options_.unicode_aware_features) {
UnicodeText token_unicode =
UTF8ToUnicodeText(token.value, /*do_copy=*/false);
- if (!token.value.empty() && u_isupper(*token_unicode.begin())) {
+    // Case feature: +1.0 if the first codepoint is uppercase, -1.0 otherwise.
+    // For an empty token begin() is the end iterator and must not be
+    // dereferenced, so the first codepoint is only read for non-empty tokens.
+    bool is_upper = false;
+    if (!token.value.empty()) {
+      const int first_codepoint = *token_unicode.begin();
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+      is_upper = u_isupper(first_codepoint);
+#else
+      TC_LOG(WARNING) << "Using non-unicode isupper because ICU is disabled.";
+      // std::isupper is undefined for values not representable as unsigned
+      // char (or EOF), so only ASCII codepoints are passed to it.
+      is_upper = first_codepoint >= 0 && first_codepoint < 0x80 &&
+                 std::isupper(first_codepoint) != 0;
+#endif
+    }
+    if (is_upper) {
dense_features->push_back(1.0);
} else {
dense_features->push_back(-1.0);
@@ -260,6 +320,7 @@
}
}
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
// Add regexp features.
if (!regex_patterns_.empty()) {
icu::UnicodeString unicode_str(token.value.c_str(), token.value.size(),
@@ -281,6 +342,23 @@
}
}
}
+#else
+  // ICU-free stand-in for the "^[^a-z]*$" regexp feature (see the
+  // constructor): +1.0 if the token contains no ASCII lowercase letter, -1.0
+  // otherwise. Comparing bytes against 'a'..'z' mirrors the regexp's
+  // character class exactly and avoids islower(), which is undefined for
+  // negative char values (UTF-8 bytes >= 0x80 when char is signed) and
+  // depends on the current locale.
+  if (enable_all_caps_feature_) {
+    bool is_all_caps = true;
+    for (const char character_byte : token.value) {
+      if (character_byte >= 'a' && character_byte <= 'z') {
+        is_all_caps = false;
+        break;
+      }
+    }
+    dense_features->push_back(is_all_caps ? 1.0 : -1.0);
+  }
+#endif
+
return true;
}
diff --git a/smartselect/token-feature-extractor.h b/smartselect/token-feature-extractor.h
index 8287fbd..5afeca4 100644
--- a/smartselect/token-feature-extractor.h
+++ b/smartselect/token-feature-extractor.h
@@ -18,12 +18,14 @@
#define LIBTEXTCLASSIFIER_SMARTSELECT_TOKEN_FEATURE_EXTRACTOR_H_
#include <memory>
+#include <unordered_set>
#include <vector>
-#include "base.h"
#include "smartselect/types.h"
#include "util/strings/stringpiece.h"
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
#include "unicode/regex.h"
+#endif
namespace libtextclassifier {
@@ -55,6 +57,12 @@
// Maximum length of a word.
int max_word_length = 20;
+
+ // List of allowed charactergrams. The extracted charactergrams are filtered
+ // using this list, and charactergrams that are not present are interpreted as
+ // out-of-vocabulary.
+ // If no allowed_chargrams are specified, all charactergrams are allowed.
+ std::unordered_set<std::string> allowed_chargrams;
};
class TokenFeatureExtractor {
@@ -73,8 +81,16 @@
std::vector<float>* dense_features) const;
int DenseFeaturesCount() const {
- return options_.extract_case_feature +
- options_.extract_selection_mask_feature + regex_patterns_.size();
+    // Total count of dense features: one per enabled boolean option, plus the
+    // regexp-derived features.
+    int count = options_.extract_case_feature;
+    count += options_.extract_selection_mask_feature;
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+    // One feature per entry in regex_patterns_.
+    count += static_cast<int>(regex_patterns_.size());
+#else
+    // Without ICU, only the all-caps pattern is supported.
+    count += enable_all_caps_feature_ ? 1 : 0;
+#endif
+    return count;
}
protected:
@@ -94,8 +110,11 @@
private:
TokenFeatureExtractorOptions options_;
-
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
std::vector<std::unique_ptr<icu::RegexPattern>> regex_patterns_;
+#else
+ bool enable_all_caps_feature_ = false;
+#endif
};
} // namespace libtextclassifier
diff --git a/tests/token-feature-extractor_test.cc b/smartselect/token-feature-extractor_test.cc
similarity index 76%
rename from tests/token-feature-extractor_test.cc
rename to smartselect/token-feature-extractor_test.cc
index c85ba50..4b635fd 100644
--- a/tests/token-feature-extractor_test.cc
+++ b/smartselect/token-feature-extractor_test.cc
@@ -98,6 +98,35 @@
EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
}
+// With chargram_orders empty and unicode_aware_features disabled, each token
+// yields a single sparse feature: the hash of the whole bounded token
+// ("^token$").
+TEST(TokenFeatureExtractorTest, ExtractAsciiNoChargrams) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = false;
+ options.extract_selection_mask_feature = true;
+ TestingTokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ // Capitalized token inside the selection span.
+ extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({extractor.HashToken("^Hello$")}));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ // Lowercase token outside the selection span.
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({extractor.HashToken("^world!$")}));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+}
+
TEST(TokenFeatureExtractorTest, ExtractUnicode) {
TokenFeatureExtractorOptions options;
options.num_buckets = 1000;
@@ -168,6 +197,36 @@
EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
}
+// Unicode-aware variant of ExtractAsciiNoChargrams. With chargram_orders empty,
+// each token yields a single sparse feature: the hash of the whole bounded
+// token, including non-ASCII characters.
+TEST(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = true;
+ options.extract_selection_mask_feature = true;
+ TestingTokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ // Capitalized non-ASCII token inside the selection span.
+ extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({extractor.HashToken("^Hělló$")}));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ // Lowercase token outside the selection span.
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray({
+ extractor.HashToken("^world!$"),
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+}
+
TEST(TokenFeatureExtractorTest, ICUCaseFeature) {
TokenFeatureExtractorOptions options;
options.num_buckets = 1000;
@@ -400,5 +459,85 @@
EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
}
+// Checks chargram filtering through allowed_chargrams. Chargrams missing from
+// the set hash to the out-of-vocabulary bucket 0, and "<PAD>" hashes to
+// bucket 1.
+TEST(TokenFeatureExtractorTest, ExtractFiltered) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2, 3};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = false;
+ options.extract_selection_mask_feature = true;
+ options.allowed_chargrams.insert("^H");
+ options.allowed_chargrams.insert("ll");
+ options.allowed_chargrams.insert("llo");
+ options.allowed_chargrams.insert("w");
+ options.allowed_chargrams.insert("!");
+ options.allowed_chargrams.insert("\xc4"); // Lead byte of the UTF-8 encoding of "ě" (U+011B) in "Hěllo".
+
+ TestingTokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ extractor.Extract(Token{"Hěllo", 0, 5}, true, &sparse_features,
+ &dense_features);
+
+ // Byte-wise chargrams of "^Hěllo$", grouped by order: six 1-grams (interior
+ // bytes only), then seven 2-grams, then six 3-grams.
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({
+ // clang-format off
+ 0,
+ extractor.HashToken("\xc4"),
+ 0,
+ 0,
+ 0,
+ 0,
+ extractor.HashToken("^H"),
+ 0,
+ 0,
+ 0,
+ extractor.HashToken("ll"),
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ extractor.HashToken("llo"),
+ 0
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+
+ // Same layout for "^world!$": six 1-grams, seven 2-grams, six 3-grams. Only
+ // the 1-grams "w" and "!" are allowed.
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray({
+ // clang-format off
+ extractor.HashToken("w"),
+ 0,
+ 0,
+ 0,
+ 0,
+ extractor.HashToken("!"),
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+ EXPECT_EQ(extractor.HashToken("<PAD>"), 1);
+}
+
} // namespace
} // namespace libtextclassifier
diff --git a/smartselect/tokenizer.cc b/smartselect/tokenizer.cc
index 2093fde..2489a61 100644
--- a/smartselect/tokenizer.cc
+++ b/smartselect/tokenizer.cc
@@ -16,49 +16,42 @@
#include "smartselect/tokenizer.h"
+#include <algorithm>
+
#include "util/strings/utf8.h"
#include "util/utf8/unicodetext.h"
namespace libtextclassifier {
-void Tokenizer::PrepareTokenizationCodepointRanges(
- const std::vector<TokenizationCodepointRange>& codepoint_range_configs) {
- codepoint_ranges_.clear();
- codepoint_ranges_.reserve(codepoint_range_configs.size());
- for (const TokenizationCodepointRange& range : codepoint_range_configs) {
- codepoint_ranges_.push_back(
- CodepointRange(range.start(), range.end(), range.role()));
- }
-
+// Copies the codepoint ranges and sorts them by start() so that
+// FindTokenizationRole can binary-search them with std::lower_bound.
+Tokenizer::Tokenizer(
+ const std::vector<TokenizationCodepointRange>& codepoint_ranges)
+ : codepoint_ranges_(codepoint_ranges) {
std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
- [](const CodepointRange& a, const CodepointRange& b) {
- return a.start < b.start;
+ [](const TokenizationCodepointRange& a,
+ const TokenizationCodepointRange& b) {
+ return a.start() < b.start();
});
}
TokenizationCodepointRange::Role Tokenizer::FindTokenizationRole(
int codepoint) const {
- auto it = std::lower_bound(codepoint_ranges_.begin(), codepoint_ranges_.end(),
- codepoint,
- [](const CodepointRange& range, int codepoint) {
- // This function compares range with the
- // codepoint for the purpose of finding the first
- // greater or equal range. Because of the use of
- // std::lower_bound it needs to return true when
- // range < codepoint; the first time it will
- // return false the lower bound is found and
- // returned.
- //
- // It might seem weird that the condition is
- // range.end <= codepoint here but when codepoint
- // == range.end it means it's actually just
- // outside of the range, thus the range is less
- // than the codepoint.
- return range.end <= codepoint;
- });
- if (it != codepoint_ranges_.end() && it->start <= codepoint &&
- it->end > codepoint) {
- return it->role;
+ auto it = std::lower_bound(
+ codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint,
+ [](const TokenizationCodepointRange& range, int codepoint) {
+ // This function compares range with the codepoint for the purpose of
+ // finding the first greater or equal range. Because of the use of
+ // std::lower_bound it needs to return true when range < codepoint;
+ // the first time it will return false the lower bound is found and
+ // returned.
+ //
+ // It might seem weird that the condition is range.end <= codepoint
+ // here but when codepoint == range.end it means it's actually just
+ // outside of the range, thus the range is less than the codepoint.
+ return range.end() <= codepoint;
+ });
+ if (it != codepoint_ranges_.end() && it->start() <= codepoint &&
+ it->end() > codepoint) {
+ return it->role();
} else {
return TokenizationCodepointRange::DEFAULT_ROLE;
}
diff --git a/smartselect/tokenizer.h b/smartselect/tokenizer.h
index 897f7c4..4eb78f9 100644
--- a/smartselect/tokenizer.h
+++ b/smartselect/tokenizer.h
@@ -22,7 +22,6 @@
#include "smartselect/tokenizer.pb.h"
#include "smartselect/types.h"
-#include "util/base/integral_types.h"
namespace libtextclassifier {
@@ -31,29 +30,12 @@
class Tokenizer {
public:
explicit Tokenizer(
- const std::vector<TokenizationCodepointRange>& codepoint_range_configs) {
- PrepareTokenizationCodepointRanges(codepoint_range_configs);
- }
+ const std::vector<TokenizationCodepointRange>& codepoint_ranges);
// Tokenizes the input string using the selected tokenization method.
std::vector<Token> Tokenize(const std::string& utf8_text) const;
protected:
- // Represents a codepoint range [start, end) with its role for tokenization.
- struct CodepointRange {
- int32 start;
- int32 end;
- TokenizationCodepointRange::Role role;
-
- CodepointRange(int32 arg_start, int32 arg_end,
- TokenizationCodepointRange::Role arg_role)
- : start(arg_start), end(arg_end), role(arg_role) {}
- };
-
- // Prepares tokenization codepoint ranges for use in tokenization.
- void PrepareTokenizationCodepointRanges(
- const std::vector<TokenizationCodepointRange>& codepoint_range_configs);
-
// Finds the tokenization role for given codepoint.
// If the character is not found returns DEFAULT_ROLE.
// Internally uses binary search so should be O(log(# of codepoint_ranges)).
@@ -62,7 +44,7 @@
private:
// Codepoint ranges that determine how different codepoints are tokenized.
// The ranges must not overlap.
- std::vector<CodepointRange> codepoint_ranges_;
+ std::vector<TokenizationCodepointRange> codepoint_ranges_;
};
} // namespace libtextclassifier
diff --git a/tests/tokenizer_test.cc b/smartselect/tokenizer_test.cc
similarity index 100%
rename from tests/tokenizer_test.cc
rename to smartselect/tokenizer_test.cc
diff --git a/textclassifier_jni.cc b/textclassifier_jni.cc
index 8d64d87..8740f4c 100644
--- a/textclassifier_jni.cc
+++ b/textclassifier_jni.cc
@@ -23,9 +23,10 @@
#include "lang_id/lang-id.h"
#include "smartselect/text-classification-model.h"
+#include "util/java/scoped_local_ref.h"
-using libtextclassifier::TextClassificationModel;
using libtextclassifier::ModelOptions;
+using libtextclassifier::TextClassificationModel;
using libtextclassifier::nlp_core::lang_id::LangId;
namespace {
@@ -38,6 +39,11 @@
}
jclass string_class = env->FindClass("java/lang/String");
+ if (!string_class) {
+ TC_LOG(ERROR) << "Can't find String class";
+ return false;
+ }
+
jmethodID get_bytes_id =
env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
@@ -69,6 +75,11 @@
JNIEnv* env, const std::string& result_class_name,
const std::vector<std::pair<std::string, float>>& classification_result) {
jclass result_class = env->FindClass(result_class_name.c_str());
+ if (!result_class) {
+ TC_LOG(ERROR) << "Couldn't find result class: " << result_class_name;
+ return nullptr;
+ }
+
jmethodID result_class_constructor =
env->GetMethodID(result_class, "<init>", "(Ljava/lang/String;F)V");
@@ -155,22 +166,64 @@
} // namespace libtextclassifier
-using libtextclassifier::ConvertIndicesUTF8ToBMP;
-using libtextclassifier::ConvertIndicesBMPToUTF8;
using libtextclassifier::CodepointSpan;
+using libtextclassifier::ConvertIndicesBMPToUTF8;
+using libtextclassifier::ConvertIndicesUTF8ToBMP;
+using libtextclassifier::ScopedLocalRef;
-JNIEXPORT jlong JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeNew(JNIEnv* env,
- jobject thiz,
- jint fd) {
+// Constructs a TextClassificationModel from the file descriptor and returns
+// the pointer as an opaque jlong handle. nativeClose deletes it.
+JNI_METHOD(jlong, SmartSelection, nativeNew)
+(JNIEnv* env, jobject thiz, jint fd) {
TextClassificationModel* model = new TextClassificationModel(fd);
return reinterpret_cast<jlong>(model);
}
-JNIEXPORT jintArray JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeSuggest(
- JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end) {
+// Constructs a TextClassificationModel from `path` (presumably a model file
+// path) and returns the pointer as an opaque jlong handle. nativeClose deletes
+// it.
+JNI_METHOD(jlong, SmartSelection, nativeNewFromPath)
+(JNIEnv* env, jobject thiz, jstring path) {
+  const std::string model_path = ToStlString(env, path);
+  return reinterpret_cast<jlong>(new TextClassificationModel(model_path));
+}
+
+// Constructs a TextClassificationModel from the OS-level file descriptor
+// wrapped by the Java AssetFileDescriptor `afd`. `offset` and `size` are
+// passed through to the model constructor. Returns the model pointer as an
+// opaque jlong handle (deleted by nativeClose), or 0 on any JNI lookup
+// failure.
+JNI_METHOD(jlong, SmartSelection, nativeNewFromAssetFileDescriptor)
+(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
+  // Get system-level file descriptor from AssetFileDescriptor.
+  ScopedLocalRef<jclass> afd_class(
+      env->FindClass("android/content/res/AssetFileDescriptor"), env);
+  if (afd_class == nullptr) {
+    TC_LOG(ERROR) << "Couln't find AssetFileDescriptor.";
+    return reinterpret_cast<jlong>(nullptr);
+  }
+  jmethodID afd_class_getFileDescriptor = env->GetMethodID(
+      afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;");
+  if (afd_class_getFileDescriptor == nullptr) {
+    TC_LOG(ERROR) << "Couln't find getFileDescriptor.";
+    return reinterpret_cast<jlong>(nullptr);
+  }
+
+  ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"),
+                                  env);
+  if (fd_class == nullptr) {
+    TC_LOG(ERROR) << "Couln't find FileDescriptor.";
+    return reinterpret_cast<jlong>(nullptr);
+  }
+  jfieldID fd_class_descriptor =
+      env->GetFieldID(fd_class.get(), "descriptor", "I");
+  if (fd_class_descriptor == nullptr) {
+    TC_LOG(ERROR) << "Couln't find descriptor.";
+    return reinterpret_cast<jlong>(nullptr);
+  }
+
+  // getFileDescriptor() can throw (CallObjectMethod then returns null) or
+  // return null. Reading a field through a null reference is a fatal JNI
+  // error, so bail out instead. A pending exception propagates to the Java
+  // caller on return.
+  ScopedLocalRef<jobject> bundle_jfd(
+      env->CallObjectMethod(afd, afd_class_getFileDescriptor), env);
+  if (env->ExceptionCheck() || bundle_jfd == nullptr) {
+    TC_LOG(ERROR) << "Couldn't get FileDescriptor from AssetFileDescriptor.";
+    return reinterpret_cast<jlong>(nullptr);
+  }
+  jint bundle_cfd = env->GetIntField(bundle_jfd.get(), fd_class_descriptor);
+
+  TextClassificationModel* model =
+      new TextClassificationModel(bundle_cfd, offset, size);
+  return reinterpret_cast<jlong>(model);
+}
+
+JNI_METHOD(jintArray, SmartSelection, nativeSuggest)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
+ jint selection_end) {
TextClassificationModel* model =
reinterpret_cast<TextClassificationModel*>(ptr);
@@ -187,10 +240,9 @@
return result;
}
-JNIEXPORT jobjectArray JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeClassifyText(
- JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end, jint input_flags) {
+JNI_METHOD(jobjectArray, SmartSelection, nativeClassifyText)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
+ jint selection_end, jint input_flags) {
TextClassificationModel* ff_model =
reinterpret_cast<TextClassificationModel*>(ptr);
const std::vector<std::pair<std::string, float>> classification_result =
@@ -198,28 +250,58 @@
{selection_begin, selection_end}, input_flags);
return ScoredStringsToJObjectArray(
- env, "android/view/textclassifier/SmartSelection$ClassificationResult",
+ env, TC_PACKAGE_PATH "SmartSelection$ClassificationResult",
classification_result);
}
-JNIEXPORT void JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeClose(JNIEnv* env,
- jobject thiz,
- jlong ptr) {
+// Runs TextClassificationModel::Annotate over `context` and returns one
+// SmartSelection$AnnotatedSpan per annotation. Span indices are converted
+// with ConvertIndicesUTF8ToBMP. Returns null if a JNI class or constructor
+// lookup, or the array allocation, fails.
+JNI_METHOD(jobjectArray, SmartSelection, nativeAnnotate)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context) {
+  TextClassificationModel* model =
+      reinterpret_cast<TextClassificationModel*>(ptr);
+  std::string context_utf8 = ToStlString(env, context);
+  std::vector<TextClassificationModel::AnnotatedSpan> annotations =
+      model->Annotate(context_utf8);
+
+  ScopedLocalRef<jclass> result_class(
+      env->FindClass(TC_PACKAGE_PATH "SmartSelection$AnnotatedSpan"), env);
+  if (result_class == nullptr) {
+    TC_LOG(ERROR) << "Couldn't find result class: "
+                  << TC_PACKAGE_PATH "SmartSelection$AnnotatedSpan";
+    return nullptr;
+  }
+
+  jmethodID result_class_constructor = env->GetMethodID(
+      result_class.get(), "<init>",
+      "(II[L" TC_PACKAGE_PATH "SmartSelection$ClassificationResult;)V");
+  if (result_class_constructor == nullptr) {
+    TC_LOG(ERROR) << "Couldn't find AnnotatedSpan constructor.";
+    return nullptr;
+  }
+
+  jobjectArray results = env->NewObjectArray(
+      static_cast<jsize>(annotations.size()), result_class.get(), nullptr);
+  if (results == nullptr) {
+    TC_LOG(ERROR) << "Couldn't allocate AnnotatedSpan array.";
+    return nullptr;
+  }
+
+  for (size_t i = 0; i < annotations.size(); ++i) {
+    CodepointSpan span_bmp =
+        ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span);
+    jobjectArray classification = ScoredStringsToJObjectArray(
+        env, TC_PACKAGE_PATH "SmartSelection$ClassificationResult",
+        annotations[i].classification);
+    jobject result = env->NewObject(
+        result_class.get(), result_class_constructor,
+        static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second),
+        classification);
+    env->SetObjectArrayElement(results, static_cast<jsize>(i), result);
+    // Release both per-iteration local references. Otherwise a long input
+    // with many annotations can overflow the JNI local reference table.
+    env->DeleteLocalRef(result);
+    env->DeleteLocalRef(classification);
+  }
+  return results;
+}
+
+// Deletes the TextClassificationModel behind a handle returned by one of the
+// SmartSelection nativeNew* functions.
+JNI_METHOD(void, SmartSelection, nativeClose)
+(JNIEnv* env, jobject thiz, jlong ptr) {
TextClassificationModel* model =
reinterpret_cast<TextClassificationModel*>(ptr);
delete model;
}
-JNIEXPORT jlong JNICALL Java_android_view_textclassifier_LangId_nativeNew(
- JNIEnv* env, jobject thiz, jint fd) {
- return reinterpret_cast<jlong>(new LangId(fd));
-}
-
-JNIEXPORT jstring JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeGetLanguage(JNIEnv* env,
- jobject clazz,
- jint fd) {
+JNI_METHOD(jstring, SmartSelection, nativeGetLanguage)
+(JNIEnv* env, jobject clazz, jint fd) {
ModelOptions model_options;
if (ReadSelectionModelOptions(fd, &model_options)) {
return env->NewStringUTF(model_options.language().c_str());
@@ -228,10 +310,8 @@
}
}
-JNIEXPORT jint JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeGetVersion(JNIEnv* env,
- jobject clazz,
- jint fd) {
+JNI_METHOD(jint, SmartSelection, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jint fd) {
ModelOptions model_options;
if (ReadSelectionModelOptions(fd, &model_options)) {
return model_options.version();
@@ -240,28 +320,31 @@
}
}
-JNIEXPORT jobjectArray JNICALL
-Java_android_view_textclassifier_LangId_nativeFindLanguages(JNIEnv* env,
- jobject thiz,
- jlong ptr,
- jstring text) {
+#ifndef LIBTEXTCLASSIFIER_DISABLE_LANG_ID
+// Constructs a LangId model from the file descriptor and returns the pointer
+// as an opaque jlong handle. LangId nativeClose deletes it.
+JNI_METHOD(jlong, LangId, nativeNew)
+(JNIEnv* env, jobject thiz, jint fd) {
+  LangId* const lang_id = new LangId(fd);
+  return reinterpret_cast<jlong>(lang_id);
+}
+
+// Calls LangId::FindLanguages on `text` using the model behind `ptr` and
+// returns the scored languages as LangId$ClassificationResult objects.
+JNI_METHOD(jobjectArray, LangId, nativeFindLanguages)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring text) {
LangId* lang_id = reinterpret_cast<LangId*>(ptr);
const std::vector<std::pair<std::string, float>> scored_languages =
lang_id->FindLanguages(ToStlString(env, text));
return ScoredStringsToJObjectArray(
- env, "android/view/textclassifier/LangId$ClassificationResult",
- scored_languages);
+ env, TC_PACKAGE_PATH "LangId$ClassificationResult", scored_languages);
}
-JNIEXPORT void JNICALL Java_android_view_textclassifier_LangId_nativeClose(
- JNIEnv* env, jobject thiz, jlong ptr) {
+// Deletes the LangId behind a handle returned by LangId nativeNew.
+JNI_METHOD(void, LangId, nativeClose)
+(JNIEnv* env, jobject thiz, jlong ptr) {
LangId* lang_id = reinterpret_cast<LangId*>(ptr);
delete lang_id;
}
-JNIEXPORT int JNICALL Java_android_view_textclassifier_LangId_nativeGetVersion(
- JNIEnv* env, jobject clazz, jint fd) {
+// Constructs a temporary LangId from the file descriptor only to read its
+// version. The unique_ptr frees the object before returning.
+JNI_METHOD(int, LangId, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jint fd) {
std::unique_ptr<LangId> lang_id(new LangId(fd));
return lang_id->version();
}
+#endif
diff --git a/textclassifier_jni.h b/textclassifier_jni.h
index 28bb444..1709ff4 100644
--- a/textclassifier_jni.h
+++ b/textclassifier_jni.h
@@ -22,56 +22,70 @@
#include "smartselect/types.h"
+// Java package of the JNI classes, in the underscore-separated form used in
+// JNI symbol names. Can be predefined (e.g. via compiler flags) to target a
+// different package.
+#ifndef TC_PACKAGE_NAME
+#define TC_PACKAGE_NAME android_view_textclassifier
+#endif
+// The same package in the slash-separated form used by FindClass and JNI type
+// signatures. Note the trailing slash: class names are appended directly.
+#ifndef TC_PACKAGE_PATH
+#define TC_PACKAGE_PATH "android/view/textclassifier/"
+#endif
+
+// Expands to the exported signature head of the native method
+// Java_<package_name>_<class_name>_<method_name>; the parameter list follows.
+#define JNI_METHOD_PRIMITIVE(return_type, package_name, class_name, \
+ method_name) \
+ JNIEXPORT return_type JNICALL \
+ Java_##package_name##_##class_name##_##method_name
+
+// The indirection is needed to correctly expand the TC_PACKAGE_NAME macro.
+#define JNI_METHOD2(return_type, package_name, class_name, method_name) \
+ JNI_METHOD_PRIMITIVE(return_type, package_name, class_name, method_name)
+
+// Declares or defines a native method of `class_name` in TC_PACKAGE_NAME.
+#define JNI_METHOD(return_type, class_name, method_name) \
+ JNI_METHOD2(return_type, TC_PACKAGE_NAME, class_name, method_name)
+
#ifdef __cplusplus
extern "C" {
#endif
// SmartSelection.
-JNIEXPORT jlong JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeNew(JNIEnv* env,
- jobject thiz,
- jint fd);
+JNI_METHOD(jlong, SmartSelection, nativeNew)
+(JNIEnv* env, jobject thiz, jint fd);
-JNIEXPORT jintArray JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeSuggest(
- JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end);
+JNI_METHOD(jlong, SmartSelection, nativeNewFromPath)
+(JNIEnv* env, jobject thiz, jstring path);
-JNIEXPORT jobjectArray JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeClassifyText(
- JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end, jint input_flags);
+JNI_METHOD(jlong, SmartSelection, nativeNewFromAssetFileDescriptor)
+(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);
-JNIEXPORT void JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeClose(JNIEnv* env,
- jobject thiz,
- jlong ptr);
+JNI_METHOD(jintArray, SmartSelection, nativeSuggest)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
+ jint selection_end);
-JNIEXPORT jstring JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeGetLanguage(JNIEnv* env,
- jobject clazz,
- jint fd);
+JNI_METHOD(jobjectArray, SmartSelection, nativeClassifyText)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
+ jint selection_end, jint input_flags);
-JNIEXPORT jint JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeGetVersion(JNIEnv* env,
- jobject clazz,
- jint fd);
+JNI_METHOD(jobjectArray, SmartSelection, nativeAnnotate)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context);
+JNI_METHOD(void, SmartSelection, nativeClose)
+(JNIEnv* env, jobject thiz, jlong ptr);
+
+JNI_METHOD(jstring, SmartSelection, nativeGetLanguage)
+(JNIEnv* env, jobject clazz, jint fd);
+
+JNI_METHOD(jint, SmartSelection, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jint fd);
+
+#ifndef LIBTEXTCLASSIFIER_DISABLE_LANG_ID
// LangId.
-JNIEXPORT jlong JNICALL Java_android_view_textclassifier_LangId_nativeNew(
- JNIEnv* env, jobject thiz, jint fd);
+JNI_METHOD(jlong, LangId, nativeNew)(JNIEnv* env, jobject thiz, jint fd);
-JNIEXPORT jobjectArray JNICALL
-Java_android_view_textclassifier_LangId_nativeFindLanguages(JNIEnv* env,
- jobject thiz,
- jlong ptr,
- jstring text);
+JNI_METHOD(jobjectArray, LangId, nativeFindLanguages)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring text);
-JNIEXPORT void JNICALL Java_android_view_textclassifier_LangId_nativeClose(
- JNIEnv* env, jobject thiz, jlong ptr);
+JNI_METHOD(void, LangId, nativeClose)(JNIEnv* env, jobject thiz, jlong ptr);
-JNIEXPORT int JNICALL Java_android_view_textclassifier_LangId_nativeGetVersion(
- JNIEnv* env, jobject clazz, jint fd);
+JNI_METHOD(int, LangId, nativeGetVersion)(JNIEnv* env, jobject clazz, jint fd);
+#endif
#ifdef __cplusplus
}
diff --git a/tests/textclassifier_jni_test.cc b/textclassifier_jni_test.cc
similarity index 100%
rename from tests/textclassifier_jni_test.cc
rename to textclassifier_jni_test.cc
diff --git a/util/base/casts.h b/util/base/casts.h
index ad12ce4..805ee89 100644
--- a/util/base/casts.h
+++ b/util/base/casts.h
@@ -21,13 +21,12 @@
namespace libtextclassifier {
-// lang_id_bit_cast<Dest,Source> is a template function that implements the
-// equivalent of "*reinterpret_cast<Dest*>(&source)". We need this in
-// very low-level functions like the protobuf library and fast math
-// support.
+// bit_cast<Dest, Source> is a template function that implements the equivalent
+// of "*reinterpret_cast<Dest*>(&source)". We need this in very low-level
+// functions like fast math support.
//
// float f = 3.14159265358979;
-// int i = lang_id_bit_cast<int32>(f);
+// int i = bit_cast<int32>(f);
// // i = 0x40490fdb
//
// The classical address-casting method is:
@@ -60,9 +59,9 @@
//
// Anyways ...
//
-// lang_id_bit_cast<> calls memcpy() which is blessed by the standard,
-// especially by the example in section 3.9 . Also, of course,
-// lang_id_bit_cast<> wraps up the nasty logic in one place.
+// bit_cast<> calls memcpy(), which is blessed by the standard, especially by
+// the example in section 3.9. Also, of course, bit_cast<> wraps up the nasty
+// logic in one place.
//
// Fortunately memcpy() is very fast. In optimized mode, with a
// constant size, gcc 2.95.3, gcc 4.0.1, and msvc 7.1 produce inline
@@ -70,15 +69,14 @@
// memcpy(d,s,4) compiles to one load and one store, and memcpy(d,s,8)
// compiles to two loads and two stores.
//
-// I tested this code with gcc 2.95.3, gcc 4.0.1, icc 8.1, and msvc 7.1.
+// Mike Chastain tested this code with gcc 2.95.3, gcc 4.0.1, icc 8.1, and msvc
+// 7.1.
//
// WARNING: if Dest or Source is a non-POD type, the result of the memcpy
// is likely to surprise you.
//
// Props to Bill Gibbons for the compile time assertion technique and
// Art Komninos and Igor Tandetnik for the msvc experiments.
-//
-// -- mec 2005-10-17
template <class Dest, class Source>
inline Dest bit_cast(const Source &source) {
diff --git a/base.h b/util/base/endian.h
similarity index 78%
rename from base.h
rename to util/base/endian.h
index 6829f1c..f319f65 100644
--- a/base.h
+++ b/util/base/endian.h
@@ -14,35 +14,44 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_BASE_H_
-#define LIBTEXTCLASSIFIER_BASE_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_ENDIAN_H_
+#define LIBTEXTCLASSIFIER_UTIL_BASE_ENDIAN_H_
-#include <cassert>
-#include <map>
-#include <string>
-#include <vector>
-
-#include "util/base/config.h"
#include "util/base/integral_types.h"
namespace libtextclassifier {
-#ifdef INTERNAL_BUILD
-typedef basic_string<char> bstring;
-#else
-typedef std::basic_string<char> bstring;
-#endif // INTERNAL_BUILD
-
#if defined OS_LINUX || defined OS_CYGWIN || defined OS_ANDROID || \
defined(__ANDROID__)
#include <endian.h>
+#elif defined(__APPLE__)
+#include <machine/endian.h>
+// Add linux style defines.
+#ifndef __BYTE_ORDER
+#define __BYTE_ORDER BYTE_ORDER
+#endif // __BYTE_ORDER
+#ifndef __LITTLE_ENDIAN
+#define __LITTLE_ENDIAN LITTLE_ENDIAN
+#endif // __LITTLE_ENDIAN
+#ifndef __BIG_ENDIAN
+#define __BIG_ENDIAN BIG_ENDIAN
+#endif // __BIG_ENDIAN
#endif
// The following guarantees declaration of the byte swap functions, and
// defines __BYTE_ORDER for MSVC
#if defined(__GLIBC__) || defined(__CYGWIN__)
#include <byteswap.h> // IWYU pragma: export
-
+// The following section defines the byte swap functions for OS X / iOS,
+// which do not ship with byteswap.h.
+#elif defined(__APPLE__)
+// Make sure that byte swap functions are not already defined.
+#if !defined(bswap_16)
+#include <libkern/OSByteOrder.h>
+#define bswap_16(x) OSSwapInt16(x)
+#define bswap_32(x) OSSwapInt32(x)
+#define bswap_64(x) OSSwapInt64(x)
+#endif // !defined(bswap_16)
#else
#define GG_LONGLONG(x) x##LL
#define GG_ULONGLONG(x) x##ULL
@@ -126,4 +135,4 @@
} // namespace libtextclassifier
-#endif // LIBTEXTCLASSIFIER_BASE_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_BASE_ENDIAN_H_
diff --git a/util/base/logging.h b/util/base/logging.h
index b0f3c5d..dba0ed4 100644
--- a/util/base/logging.h
+++ b/util/base/logging.h
@@ -24,6 +24,23 @@
#include "util/base/logging_levels.h"
#include "util/base/port.h"
+// TC_STRIP
+namespace libtextclassifier {
+// A string class that can't be instantiated. Makes sure that code using the
+// unqualified `string` (instead of std::string) does not compile.
+//
+// NOTE: defined here because most files directly or transitively include this
+// file. Asking people to include a special header just to make sure they don't
+// use the unqualified string doesn't work: as that header doesn't produce any
+// immediate benefit, one can easily forget about it.
+class string {
+ public:
+ // Makes the class non-instantiable.
+ virtual ~string() = 0;
+};
+} // namespace libtextclassifier
+// TC_END_STRIP
+
namespace libtextclassifier {
namespace logging {
@@ -75,10 +92,6 @@
#define TC_CHECK_GE(x, y) TC_CHECK((x) >= (y))
#define TC_CHECK_NE(x, y) TC_CHECK((x) != (y))
-// Debug checks: a TC_DCHECK<suffix> macro should behave like TC_CHECK<suffix>
-// in debug mode an don't check / don't print anything in non-debug mode.
-#ifdef NDEBUG
-
// Pseudo-stream that "eats" the tokens <<-pumped into it, without printing
// anything.
class NullStream {
@@ -92,6 +105,11 @@
}
#define TC_NULLSTREAM ::libtextclassifier::logging::NullStream().stream()
+
+// Debug checks: a TC_DCHECK<suffix> macro should behave like TC_CHECK<suffix>
+// in debug mode and not check / print anything in non-debug mode.
+#ifdef NDEBUG
+
#define TC_DCHECK(x) TC_NULLSTREAM
#define TC_DCHECK_EQ(x, y) TC_NULLSTREAM
#define TC_DCHECK_LT(x, y) TC_NULLSTREAM
@@ -113,6 +131,16 @@
#define TC_DCHECK_NE(x, y) TC_CHECK_NE(x, y)
#endif // NDEBUG
+
+#ifdef LIBTEXTCLASSIFIER_VLOG
+#define TC_VLOG(severity) \
+ ::libtextclassifier::logging::LogMessage(::libtextclassifier::logging::INFO, \
+ __FILE__, __LINE__) \
+ .stream()
+#else
+#define TC_VLOG(severity) TC_NULLSTREAM
+#endif
+
} // namespace logging
} // namespace libtextclassifier
diff --git a/util/hash/farmhash.cc b/util/hash/farmhash.cc
index 55786a9..673f45f 100644
--- a/util/hash/farmhash.cc
+++ b/util/hash/farmhash.cc
@@ -348,10 +348,7 @@
return x;
}
-} // namespace NAMESPACE_FOR_HASH_FUNCTIONS;
-
using namespace std;
-using namespace NAMESPACE_FOR_HASH_FUNCTIONS;
namespace farmhashna {
#undef Fetch
#define Fetch Fetch64
@@ -642,7 +639,7 @@
uint32_t Hash32(const char *s, size_t len) {
FARMHASH_DIE_IF_MISCONFIGURED;
- return s == NULL ? 0 : len;
+ return s == nullptr ? 0 : len;
}
uint32_t Hash32WithSeed(const char *s, size_t len, uint32_t seed) {
@@ -865,7 +862,7 @@
uint32_t Hash32(const char *s, size_t len) {
FARMHASH_DIE_IF_MISCONFIGURED;
- return s == NULL ? 0 : len;
+ return s == nullptr ? 0 : len;
}
uint32_t Hash32WithSeed(const char *s, size_t len, uint32_t seed) {
@@ -894,7 +891,7 @@
uint32_t Hash32(const char *s, size_t len) {
FARMHASH_DIE_IF_MISCONFIGURED;
- return s == NULL ? 0 : len;
+ return s == nullptr ? 0 : len;
}
uint32_t Hash32WithSeed(const char *s, size_t len, uint32_t seed) {
@@ -1407,7 +1404,6 @@
return CityHash128(s, len);
}
} // namespace farmhashcc
-namespace NAMESPACE_FOR_HASH_FUNCTIONS {
// BASIC STRING HASHING
diff --git a/util/java/scoped_local_ref.h b/util/java/scoped_local_ref.h
new file mode 100644
index 0000000..d995468
--- /dev/null
+++ b/util/java/scoped_local_ref.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_LOCAL_REF_H_
+#define LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_LOCAL_REF_H_
+
+#include <jni.h>
+#include <memory>
+#include <type_traits>
+
+#include "util/base/logging.h"
+
+namespace libtextclassifier {
+
+// A deleter to be used with std::unique_ptr to delete JNI local references.
+class LocalRefDeleter {
+ public:
+ // Style guide violating implicit constructor so that the LocalRefDeleter
+ // is implicitly constructed from the second argument to ScopedLocalRef.
+ LocalRefDeleter(JNIEnv* env) : env_(env) {} // NOLINT(runtime/explicit)
+
+ LocalRefDeleter(const LocalRefDeleter& orig) = default;
+
+ // Copy assignment to allow move semantics in ScopedLocalRef.
+ LocalRefDeleter& operator=(const LocalRefDeleter& rhs) {
+ // As the deleter and its state are thread-local, ensure the envs
+ // are consistent but do nothing.
+ TC_CHECK_EQ(env_, rhs.env_);
+ return *this;
+ }
+
+ // The delete operator.
+ void operator()(jobject o) const { env_->DeleteLocalRef(o); }
+
+ private:
+ // The env_ stashed to use for deletion. Thread-local, don't share!
+ JNIEnv* const env_;
+};
+
+// A smart pointer that deletes a JNI local reference when it goes out
+// of scope. Usage is:
+// ScopedLocalRef<jobject> scoped_local(env->JniFunction(), env);
+//
+// Note that this type is not thread-safe: its deleter caches the JNIEnv.
+// Do not use the same ScopedLocalRef (or its jobject) across threads.
+template <typename T>
+using ScopedLocalRef =
+ std::unique_ptr<typename std::remove_pointer<T>::type, LocalRefDeleter>;
+
+} // namespace libtextclassifier
+
+#endif // LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_LOCAL_REF_H_
diff --git a/tests/numbers_test.cc b/util/strings/numbers_test.cc
similarity index 100%
rename from tests/numbers_test.cc
rename to util/strings/numbers_test.cc
diff --git a/util/utf8/unicodetext.cc b/util/utf8/unicodetext.cc
index e83c890..dbab1c8 100644
--- a/util/utf8/unicodetext.cc
+++ b/util/utf8/unicodetext.cc
@@ -16,7 +16,10 @@
#include "util/utf8/unicodetext.h"
-#include "base.h"
+#include <string.h>
+
+#include <algorithm>
+
#include "util/strings/utf8.h"
namespace libtextclassifier {
@@ -108,6 +111,8 @@
void UnicodeText::clear() { repr_.clear(); }
+int UnicodeText::size() const { return std::distance(begin(), end()); }
+
std::string UnicodeText::UTF8Substring(const const_iterator& first,
const const_iterator& last) {
return std::string(first.it_, last.it_ - first.it_);
diff --git a/util/utf8/unicodetext.h b/util/utf8/unicodetext.h
index 5327383..6a21058 100644
--- a/util/utf8/unicodetext.h
+++ b/util/utf8/unicodetext.h
@@ -17,9 +17,11 @@
#ifndef LIBTEXTCLASSIFIER_UTIL_UTF8_UNICODETEXT_H_
#define LIBTEXTCLASSIFIER_UTIL_UTF8_UNICODETEXT_H_
+#include <iterator>
+#include <string>
#include <utility>
-#include "base.h"
+#include "util/base/integral_types.h"
namespace libtextclassifier {
@@ -137,6 +139,7 @@
const_iterator begin() const;
const_iterator end() const;
+ int size() const; // the number of Unicode characters (codepoints)
// x.PointToUTF8(buf,len) changes x so that it points to buf
// ("becomes an alias"). It does not take ownership or copy buf.
@@ -162,7 +165,7 @@
int capacity_;
bool ours_; // Do we own data_?
- Repr() : data_(NULL), size_(0), capacity_(0), ours_(true) {}
+ Repr() : data_(nullptr), size_(0), capacity_(0), ours_(true) {}
~Repr() {
if (ours_) delete[] data_;
}