Snap for 4402171 from 134f3faa4d7c7c8a062e7d581d23537800e6fb11 to pi-release
Change-Id: If246f1d5c3e6cb8a8c1d738e86fdd62605ba5f47
diff --git a/Android.mk b/Android.mk
index 0959572..467c5d7 100644
--- a/Android.mk
+++ b/Android.mk
@@ -64,7 +64,7 @@
LOCAL_CFLAGS += $(MY_LIBTEXTCLASSIFIER_CFLAGS)
LOCAL_STRIP_MODULE := $(LIBTEXTCLASSIFIER_STRIP_OPTS)
-LOCAL_SRC_FILES := $(filter-out tests/%,$(call all-subdir-cpp-files))
+LOCAL_SRC_FILES := $(filter-out tests/% %_test.cc,$(call all-subdir-cpp-files))
LOCAL_C_INCLUDES += $(proto_sources_dir)/proto/external/libtextclassifier
LOCAL_STATIC_LIBRARIES += libtextclassifier_protos
diff --git a/tests/embedding-feature-extractor_test.cc b/common/embedding-feature-extractor_test.cc
similarity index 100%
rename from tests/embedding-feature-extractor_test.cc
rename to common/embedding-feature-extractor_test.cc
diff --git a/tests/embedding-network_test.cc b/common/embedding-network_test.cc
similarity index 100%
rename from tests/embedding-network_test.cc
rename to common/embedding-network_test.cc
diff --git a/tests/fml-parser_test.cc b/common/fml-parser_test.cc
similarity index 100%
rename from tests/fml-parser_test.cc
rename to common/fml-parser_test.cc
diff --git a/common/little-endian-data.h b/common/little-endian-data.h
index 8441cd7..e3bc88f 100644
--- a/common/little-endian-data.h
+++ b/common/little-endian-data.h
@@ -21,7 +21,7 @@
#include <string>
#include <vector>
-#include "base.h"
+#include "util/base/endian.h"
#include "util/base/logging.h"
namespace libtextclassifier {
diff --git a/common/memory_image/data-store.cc b/common/memory_image/data-store.cc
index dcdaa82..a5f500c 100644
--- a/common/memory_image/data-store.cc
+++ b/common/memory_image/data-store.cc
@@ -29,12 +29,12 @@
}
}
-DataBlobView DataStore::GetData(const std::string &name) const {
+StringPiece DataStore::GetData(const std::string &name) const {
if (!reader_.success_status()) {
TC_LOG(ERROR) << "DataStore::GetData(" << name << ") "
<< "called on invalid DataStore; will return empty data "
<< "chunk";
- return DataBlobView();
+ return StringPiece();
}
const auto &entries = reader_.trimmed_proto().entries();
@@ -42,14 +42,14 @@
if (it == entries.end()) {
TC_LOG(ERROR) << "Unknown key: " << name
<< "; will return empty data chunk";
- return DataBlobView();
+ return StringPiece();
}
const DataStoreEntryBytes &entry_bytes = it->second;
if (!entry_bytes.has_blob_index()) {
TC_LOG(ERROR) << "DataStoreEntryBytes with no blob_index; "
<< "will return empty data chunk.";
- return DataBlobView();
+ return StringPiece();
}
int blob_index = entry_bytes.blob_index();
diff --git a/common/memory_image/data-store.h b/common/memory_image/data-store.h
index 9af5e0e..56aa4fc 100644
--- a/common/memory_image/data-store.h
+++ b/common/memory_image/data-store.h
@@ -20,7 +20,6 @@
#include <string>
#include "common/memory_image/data-store.pb.h"
-#include "common/memory_image/memory-image-common.h"
#include "common/memory_image/memory-image-reader.h"
#include "util/strings/stringpiece.h"
@@ -46,7 +45,7 @@
// If the alignment is a low power of 2 (e..g, 4, 8, or 16) and "start" passed
// to constructor corresponds to the beginning of a memory page or an address
// returned by new or malloc(), then start_addr is divisible with alignment.
- DataBlobView GetData(const std::string &name) const;
+ StringPiece GetData(const std::string &name) const;
private:
MemoryImageReader<DataStoreProto> reader_;
diff --git a/common/memory_image/embedding-network-params-from-image.h b/common/memory_image/embedding-network-params-from-image.h
index feb4817..e8c7d1e 100644
--- a/common/memory_image/embedding-network-params-from-image.h
+++ b/common/memory_image/embedding-network-params-from-image.h
@@ -22,6 +22,7 @@
#include "common/embedding-network.pb.h"
#include "common/memory_image/memory-image-reader.h"
#include "util/base/integral_types.h"
+#include "util/strings/stringpiece.h"
namespace libtextclassifier {
namespace nlp_core {
@@ -84,7 +85,7 @@
const int blob_index = trimmed_proto_.embeddings(i).is_quantized()
? (embeddings_blob_offset_ + 2 * i)
: (embeddings_blob_offset_ + i);
- DataBlobView data_blob_view = memory_reader_.data_blob_view(blob_index);
+ StringPiece data_blob_view = memory_reader_.data_blob_view(blob_index);
return data_blob_view.data();
}
@@ -104,7 +105,7 @@
// one blob with the quantized values and (immediately after it, hence the
// "+ 1") one blob with the scales.
int blob_index = embeddings_blob_offset_ + 2 * i + 1;
- DataBlobView data_blob_view = memory_reader_.data_blob_view(blob_index);
+ StringPiece data_blob_view = memory_reader_.data_blob_view(blob_index);
return reinterpret_cast<const float16 *>(data_blob_view.data());
} else {
return nullptr;
@@ -125,7 +126,7 @@
const void *hidden_weights(int i) const override {
TC_DCHECK(InRange(i, hidden_size()));
- DataBlobView data_blob_view =
+ StringPiece data_blob_view =
memory_reader_.data_blob_view(hidden_blob_offset_ + i);
return data_blob_view.data();
}
@@ -146,7 +147,7 @@
const void *hidden_bias_weights(int i) const override {
TC_DCHECK(InRange(i, hidden_bias_size()));
- DataBlobView data_blob_view =
+ StringPiece data_blob_view =
memory_reader_.data_blob_view(hidden_bias_blob_offset_ + i);
return data_blob_view.data();
}
@@ -167,7 +168,7 @@
const void *softmax_weights(int i) const override {
TC_DCHECK(InRange(i, softmax_size()));
- DataBlobView data_blob_view =
+ StringPiece data_blob_view =
memory_reader_.data_blob_view(softmax_blob_offset_ + i);
return data_blob_view.data();
}
@@ -188,7 +189,7 @@
const void *softmax_bias_weights(int i) const override {
TC_DCHECK(InRange(i, softmax_bias_size()));
- DataBlobView data_blob_view =
+ StringPiece data_blob_view =
memory_reader_.data_blob_view(softmax_bias_blob_offset_ + i);
return data_blob_view.data();
}
diff --git a/common/memory_image/in-memory-model-data.cc b/common/memory_image/in-memory-model-data.cc
index f5cda3a..acf3d86 100644
--- a/common/memory_image/in-memory-model-data.cc
+++ b/common/memory_image/in-memory-model-data.cc
@@ -17,7 +17,6 @@
#include "common/memory_image/in-memory-model-data.h"
#include "common/file-utils.h"
-#include "common/memory_image/memory-image-common.h"
#include "util/base/logging.h"
#include "util/strings/stringpiece.h"
@@ -28,14 +27,13 @@
const char InMemoryModelData::kFilePatternPrefix[] = "in-mem-model::";
bool InMemoryModelData::GetTaskSpec(TaskSpec *task_spec) const {
- DataBlobView blob = data_store_.GetData(kTaskSpecDataStoreEntryName);
+ StringPiece blob = data_store_.GetData(kTaskSpecDataStoreEntryName);
if (blob.data() == nullptr) {
TC_LOG(ERROR) << "Can't find data blob for TaskSpec, i.e., entry "
<< kTaskSpecDataStoreEntryName;
return false;
}
- bool parse_status = file_utils::ParseProtoFromMemory(
- blob.to_stringpiece(), task_spec);
+ bool parse_status = file_utils::ParseProtoFromMemory(blob, task_spec);
if (!parse_status) {
TC_LOG(ERROR) << "Error parsing TaskSpec";
return false;
@@ -43,13 +41,5 @@
return true;
}
-StringPiece InMemoryModelData::GetBytesForInputFile(
- const std::string &file_name) const {
- // TODO(salcianu): replace our DataBlobView with StringPiece everywhere.
- DataBlobView blob = data_store_.GetData(file_name);
- return StringPiece(reinterpret_cast<const char *>(blob.data()),
- blob.size());
-}
-
} // namespace nlp_core
} // namespace libtextclassifier
diff --git a/common/memory_image/in-memory-model-data.h b/common/memory_image/in-memory-model-data.h
index 5c26388..91e4436 100644
--- a/common/memory_image/in-memory-model-data.h
+++ b/common/memory_image/in-memory-model-data.h
@@ -62,7 +62,9 @@
// file_pattern for a TaskInput from the TaskSpec (see GetTaskSpec()).
// Returns a StringPiece indicating a memory area with the content bytes. On
// error, returns StringPiece(nullptr, 0).
- StringPiece GetBytesForInputFile(const std::string &file_name) const;
+  StringPiece GetBytesForInputFile(const std::string &file_name) const {
+    const StringPiece bytes = data_store_.GetData(file_name);
+    return bytes;
+  }
private:
const memory_image::DataStore data_store_;
diff --git a/common/memory_image/low-level-memory-reader.h b/common/memory_image/low-level-memory-reader.h
index 91953d8..c87c772 100644
--- a/common/memory_image/low-level-memory-reader.h
+++ b/common/memory_image/low-level-memory-reader.h
@@ -21,10 +21,10 @@
#include <string>
-#include "base.h"
-#include "common/memory_image/memory-image-common.h"
+#include "util/base/endian.h"
#include "util/base/integral_types.h"
#include "util/base/logging.h"
+#include "util/strings/stringpiece.h"
namespace libtextclassifier {
namespace nlp_core {
@@ -61,7 +61,7 @@
// On success, sets *view to be a view of the relevant bytes: view.data()
// points to the beginning of the string bytes, and view.size() is the number
// of such bytes.
- bool ReadString(DataBlobView *view) {
+ bool ReadString(StringPiece *view) {
uint32 size;
if (!Read(&size, sizeof(size))) {
TC_LOG(ERROR) << "Unable to read std::string size";
@@ -73,15 +73,15 @@
<< " available < " << size << " required ";
return false;
}
- *view = DataBlobView(current_, size);
+ *view = StringPiece(current_, size);
Advance(size);
return true;
}
- // Like ReadString(DataBlobView *) but reads directly into a C++ string,
- // instead of a DataBlobView (StringPiece-like object).
+ // Like ReadString(StringPiece *) but reads directly into a C++ string,
+  // instead of returning a StringPiece view.
bool ReadString(std::string *target) {
- DataBlobView view;
+ StringPiece view;
if (!ReadString(&view)) {
return false;
}
diff --git a/common/memory_image/memory-image-common.h b/common/memory_image/memory-image-common.h
index 2e84116..3a46f49 100644
--- a/common/memory_image/memory-image-common.h
+++ b/common/memory_image/memory-image-common.h
@@ -35,38 +35,6 @@
static const int kDefaultAlignment;
};
-// Read-only "view" of a data blob. Does not own the underlying data; instead,
-// just a small object that points to an area of a memory image.
-//
-// TODO(salcianu): replace this class with StringPiece.
-class DataBlobView {
- public:
- DataBlobView() : DataBlobView(nullptr, 0) {}
-
- DataBlobView(const void *start, size_t size)
- : start_(start), size_(size) {}
-
- // Returns start address of a data blob from a memory image.
- const void *data() const { return start_; }
-
- // Returns number of bytes from the data blob starting at start().
- size_t size() const { return size_; }
-
- StringPiece to_stringpiece() const {
- return StringPiece(reinterpret_cast<const char *>(data()),
- size());
- }
-
- // Returns a std::string containing a copy of the data blob bytes.
- std::string ToString() const {
- return to_stringpiece().ToString();
- }
-
- private:
- const void *start_; // Not owned.
- size_t size_;
-};
-
} // namespace nlp_core
} // namespace libtextclassifier
diff --git a/common/memory_image/memory-image-reader.cc b/common/memory_image/memory-image-reader.cc
index c780412..7e717d5 100644
--- a/common/memory_image/memory-image-reader.cc
+++ b/common/memory_image/memory-image-reader.cc
@@ -18,10 +18,10 @@
#include <string>
-#include "base.h"
#include "common/memory_image/low-level-memory-reader.h"
#include "common/memory_image/memory-image-common.h"
#include "common/memory_image/memory-image.pb.h"
+#include "util/base/endian.h"
#include "util/base/logging.h"
namespace libtextclassifier {
diff --git a/common/memory_image/memory-image-reader.h b/common/memory_image/memory-image-reader.h
index 0a20805..c5954fd 100644
--- a/common/memory_image/memory-image-reader.h
+++ b/common/memory_image/memory-image-reader.h
@@ -22,11 +22,11 @@
#include <string>
#include <vector>
-#include "common/memory_image/memory-image-common.h"
#include "common/memory_image/memory-image.pb.h"
#include "util/base/integral_types.h"
#include "util/base/logging.h"
#include "util/base/macros.h"
+#include "util/strings/stringpiece.h"
namespace libtextclassifier {
namespace nlp_core {
@@ -66,11 +66,11 @@
}
// Returns pointer to the beginning of the data blob #i.
- DataBlobView data_blob_view(int i) const {
+ StringPiece data_blob_view(int i) const {
if ((i < 0) || (i >= num_data_blobs())) {
TC_LOG(ERROR) << "Blob index " << i << " outside range [0, "
<< num_data_blobs() << "); will return empty data chunk";
- return DataBlobView();
+ return StringPiece();
}
return data_blob_views_[i];
}
@@ -81,6 +81,12 @@
return trimmed_proto_serialization_.ToString();
}
+  // Returns the trimmed proto serialization as a StringPiece that points into
+  // the memory image; unlike the string-returning accessor above, no bytes are
+  // copied.
+  StringPiece trimmed_proto_view() const {
+    const StringPiece view = trimmed_proto_serialization_;
+    return view;
+  }
+
const MemoryImageHeader &header() { return header_; }
protected:
@@ -101,13 +107,13 @@
MemoryImageHeader header_;
// Binary serialization of the trimmed version of the original proto.
- // Represented as a DataBlobView backed up by the underlying memory image
+ // Represented as a StringPiece backed up by the underlying memory image
// bytes.
- DataBlobView trimmed_proto_serialization_;
+ StringPiece trimmed_proto_serialization_;
- // List of DataBlobView objects for all data blobs from the memory image (in
+ // List of StringPiece objects for all data blobs from the memory image (in
// order).
- std::vector<DataBlobView> data_blob_views_;
+ std::vector<StringPiece> data_blob_views_;
// Memory reading success status.
bool success_;
diff --git a/common/mmap.cc b/common/mmap.cc
index d652b3c..6e15a84 100644
--- a/common/mmap.cc
+++ b/common/mmap.cc
@@ -31,13 +31,9 @@
namespace nlp_core {
namespace {
-inline std::string GetLastSystemError() {
- return std::string(strerror(errno));
-}
+// Returns the description of the current errno value as a std::string.
+inline std::string GetLastSystemError() { return strerror(errno); }
-inline MmapHandle GetErrorMmapHandle() {
- return MmapHandle(nullptr, 0);
-}
+// Sentinel handle (null start, zero bytes) returned on every failure path.
+inline MmapHandle GetErrorMmapHandle() { return {nullptr, 0}; }
class FileCloser {
public:
@@ -49,11 +45,13 @@
TC_LOG(ERROR) << "Error closing file descriptor: " << last_error;
}
}
+
private:
const int fd_;
TC_DISALLOW_COPY_AND_ASSIGN(FileCloser);
};
+
} // namespace
MmapHandle MmapFile(const std::string &filename) {
@@ -81,7 +79,15 @@
TC_LOG(ERROR) << "Unable to stat fd: " << last_error;
return GetErrorMmapHandle();
}
- size_t file_size_in_bytes = static_cast<size_t>(sb.st_size);
+
+ return MmapFile(fd, /*segment_offset=*/0, /*segment_size=*/sb.st_size);
+}
+
+MmapHandle MmapFile(int fd, int64 segment_offset, int64 segment_size) {
+ static const int64 kPageSize = sysconf(_SC_PAGE_SIZE);
+ const int64 aligned_offset = (segment_offset / kPageSize) * kPageSize;
+ const int64 alignment_shift = segment_offset - aligned_offset;
+ const int64 aligned_length = segment_size + alignment_shift;
// Perform actual mmap.
void *mmap_addr = mmap(
@@ -89,8 +95,7 @@
// Let system pick address for mmapp-ed data.
nullptr,
- // Mmap all bytes from the file.
- file_size_in_bytes,
+ aligned_length,
// One can read / write the mapped data (but see MAP_PRIVATE below).
// Normally, we expect only to read it, but in the future, we may want to
@@ -104,16 +109,15 @@
// Descriptor of file to mmap.
fd,
- // Map bytes right from the beginning of the file. This, and
- // file_size_in_bytes (2nd argument) means we map all bytes from the file.
- 0);
+ aligned_offset);
if (mmap_addr == MAP_FAILED) {
const std::string last_error = GetLastSystemError();
- TC_LOG(ERROR) << "Error while mmaping: " << last_error;
+ TC_LOG(ERROR) << "Error while mmapping: " << last_error;
return GetErrorMmapHandle();
}
- return MmapHandle(mmap_addr, file_size_in_bytes);
+ return MmapHandle(static_cast<char *>(mmap_addr) + alignment_shift,
+ segment_size, /*unmap_addr=*/mmap_addr);
}
bool Unmap(MmapHandle mmap_handle) {
@@ -121,7 +125,12 @@
// Unmapping something that hasn't been mapped is trivially successful.
return true;
}
- if (munmap(mmap_handle.start(), mmap_handle.num_bytes()) != 0) {
+  // The mapping starts at unmap_addr(), which precedes start() when the
+  // segment offset was rounded down to a page boundary; unmap all of it.
+  const size_t alignment_shift = static_cast<char *>(mmap_handle.start()) -
+                                 static_cast<char *>(mmap_handle.unmap_addr());
+  if (munmap(mmap_handle.unmap_addr(),
+             mmap_handle.num_bytes() + alignment_shift) != 0) {
const std::string last_error = GetLastSystemError();
TC_LOG(ERROR) << "Error during Unmap / munmap: " << last_error;
return false;
diff --git a/common/mmap.h b/common/mmap.h
index ec79c9c..69f7b4c 100644
--- a/common/mmap.h
+++ b/common/mmap.h
@@ -21,6 +21,7 @@
#include <string>
+#include "util/base/integral_types.h"
#include "util/strings/stringpiece.h"
namespace libtextclassifier {
@@ -39,12 +40,22 @@
// are ok keeping that file in memory the whole time).
class MmapHandle {
public:
- MmapHandle(void *start, size_t num_bytes)
- : start_(start), num_bytes_(num_bytes) {}
+ MmapHandle(void *start, size_t num_bytes, void *unmap_addr = nullptr)
+ : start_(start), num_bytes_(num_bytes), unmap_addr_(unmap_addr) {}
// Returns start address for the memory area where a file has been mmapped.
void *start() const { return start_; }
+  // Returns the address to pass to munmap(): the unmap_addr given to the
+  // constructor when it was non-null, otherwise start().
+  void *unmap_addr() const {
+    return unmap_addr_ == nullptr ? start_ : unmap_addr_;
+  }
+
// Returns number of bytes of the memory area from start().
size_t num_bytes() const { return num_bytes_; }
@@ -63,6 +74,9 @@
// See doc for num_bytes().
const size_t num_bytes_;
+
+ // Address to use for unmapping.
+ void *const unmap_addr_;
};
// Maps the full content of a file in memory (using mmap).
@@ -88,6 +102,13 @@
// Like MmapFile(const std::string &filename), but uses a file descriptor.
MmapHandle MmapFile(int fd);
+// Maps a segment of a file to memory. File is given by a file descriptor, and
+// offset (relative to the beginning of the file) and size specify the segment
+// to be mapped. NOTE: Internally, we align the offset for the call to mmap
+// system call to be a multiple of page size, so offset does NOT have to be a
+// multiple of the page size.
+MmapHandle MmapFile(int fd, int64 segment_offset, int64 segment_size);
+
// Unmaps a file mapped using MmapFile. Returns true on success, false
// otherwise.
bool Unmap(MmapHandle mmap_handle);
@@ -99,8 +120,10 @@
explicit ScopedMmap(const std::string &filename)
: handle_(MmapFile(filename)) {}
- explicit ScopedMmap(int fd)
- : handle_(MmapFile(fd)) {}
+ explicit ScopedMmap(int fd) : handle_(MmapFile(fd)) {}
+
+  // Maps only the segment [segment_offset, segment_offset + segment_size) of
+  // fd; see MmapFile(int, int64, int64). The parameters are int64 to match
+  // MmapFile, so offsets/sizes that do not fit in an int are not truncated.
+  ScopedMmap(int fd, int64 segment_offset, int64 segment_size)
+      : handle_(MmapFile(fd, segment_offset, segment_size)) {}
~ScopedMmap() {
if (handle_.ok()) {
diff --git a/tests/functions.cc b/common/mock_functions.cc
similarity index 95%
rename from tests/functions.cc
rename to common/mock_functions.cc
index 8ea5a8d..c661b70 100644
--- a/tests/functions.cc
+++ b/common/mock_functions.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "tests/functions.h"
+#include "common/mock_functions.h"
#include "common/registry.h"
diff --git a/tests/functions.h b/common/mock_functions.h
similarity index 91%
rename from tests/functions.h
rename to common/mock_functions.h
index b96fe2d..b5bcb07 100644
--- a/tests/functions.h
+++ b/common/mock_functions.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_TESTS_FUNCTIONS_H_
-#define LIBTEXTCLASSIFIER_TESTS_FUNCTIONS_H_
+#ifndef LIBTEXTCLASSIFIER_COMMON_MOCK_FUNCTIONS_H_
+#define LIBTEXTCLASSIFIER_COMMON_MOCK_FUNCTIONS_H_
#include <math.h>
@@ -24,6 +24,7 @@
namespace libtextclassifier {
namespace nlp_core {
namespace functions {
+
// Abstract double -> double function.
class Function : public RegisterableClass<Function> {
public:
@@ -61,6 +62,7 @@
int Evaluate(int k) override { return k + 1; }
TC_DEFINE_REGISTRATION_METHOD("dec", Dec);
};
+
} // namespace functions
// Should be inside namespace libtextclassifier::nlp_core.
@@ -70,4 +72,4 @@
} // namespace nlp_core
} // namespace libtextclassifier
-#endif // LIBTEXTCLASSIFIER_TESTS_FUNCTIONS_H_
+#endif // LIBTEXTCLASSIFIER_COMMON_MOCK_FUNCTIONS_H_
diff --git a/tests/registry_test.cc b/common/registry_test.cc
similarity index 98%
rename from tests/registry_test.cc
rename to common/registry_test.cc
index 7de4163..d5d7006 100644
--- a/tests/registry_test.cc
+++ b/common/registry_test.cc
@@ -16,7 +16,7 @@
#include <memory>
-#include "tests/functions.h"
+#include "common/mock_functions.h"
#include "gtest/gtest.h"
namespace libtextclassifier {
diff --git a/tests/lang-id_test.cc b/lang_id/lang-id_test.cc
similarity index 99%
rename from tests/lang-id_test.cc
rename to lang_id/lang-id_test.cc
index 3a735c2..2f8aedd 100644
--- a/tests/lang-id_test.cc
+++ b/lang_id/lang-id_test.cc
@@ -21,7 +21,6 @@
#include <utility>
#include <vector>
-#include "base.h"
#include "util/base/logging.h"
#include "gtest/gtest.h"
diff --git a/smartselect/cached-features.h b/smartselect/cached-features.h
index 6490748..990233c 100644
--- a/smartselect/cached-features.h
+++ b/smartselect/cached-features.h
@@ -20,7 +20,6 @@
#include <memory>
#include <vector>
-#include "base.h"
#include "common/vector-span.h"
#include "smartselect/types.h"
diff --git a/tests/cached-features_test.cc b/smartselect/cached-features_test.cc
similarity index 100%
rename from tests/cached-features_test.cc
rename to smartselect/cached-features_test.cc
diff --git a/smartselect/feature-processor.cc b/smartselect/feature-processor.cc
index 1b15982..08f18ea 100644
--- a/smartselect/feature-processor.cc
+++ b/smartselect/feature-processor.cc
@@ -24,9 +24,11 @@
#include "util/base/logging.h"
#include "util/strings/utf8.h"
#include "util/utf8/unicodetext.h"
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
#include "unicode/brkiter.h"
#include "unicode/errorcode.h"
#include "unicode/uchar.h"
+#endif
namespace libtextclassifier {
@@ -51,6 +53,10 @@
extractor_options.remap_digits = options.remap_digits();
extractor_options.lowercase_tokens = options.lowercase_tokens();
+ for (const auto& chargram : options.allowed_chargrams()) {
+ extractor_options.allowed_chargrams.insert(chargram);
+ }
+
return extractor_options;
}
@@ -173,8 +179,10 @@
} // namespace internal
std::string FeatureProcessor::GetDefaultCollection() const {
- if (options_.default_collection() >= options_.collections_size()) {
- TC_LOG(ERROR) << "No collections specified. Returning empty string.";
+ if (options_.default_collection() < 0 ||
+ options_.default_collection() >= options_.collections_size()) {
+ TC_LOG(ERROR)
+ << "Invalid or missing default collection. Returning empty string.";
return "";
}
return options_.collections(options_.default_collection());
@@ -217,18 +225,32 @@
return false;
}
- const int result_begin_token = token_span.first;
- const int result_begin_codepoint =
- tokens[options_.context_size() - result_begin_token].start;
- const int result_end_token = token_span.second;
- const int result_end_codepoint =
- tokens[options_.context_size() + result_end_token].end;
+ const int result_begin_token_index = token_span.first;
+ const Token& result_begin_token =
+ tokens[options_.context_size() - result_begin_token_index];
+ const int result_begin_codepoint = result_begin_token.start;
+ const int result_end_token_index = token_span.second;
+ const Token& result_end_token =
+ tokens[options_.context_size() + result_end_token_index];
+ const int result_end_codepoint = result_end_token.end;
if (result_begin_codepoint == kInvalidIndex ||
result_end_codepoint == kInvalidIndex) {
*span = CodepointSpan({kInvalidIndex, kInvalidIndex});
} else {
- *span = CodepointSpan({result_begin_codepoint, result_end_codepoint});
+ const UnicodeText token_begin_unicode =
+ UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false);
+ UnicodeText::const_iterator token_begin = token_begin_unicode.begin();
+ const UnicodeText token_end_unicode =
+ UTF8ToUnicodeText(result_end_token.value, /*do_copy=*/false);
+ UnicodeText::const_iterator token_end = token_end_unicode.end();
+
+ const int begin_ignored = CountIgnoredSpanBoundaryCodepoints(
+ token_begin, token_begin_unicode.end(), /*count_from_beginning=*/true);
+ const int end_ignored = CountIgnoredSpanBoundaryCodepoints(
+ token_end_unicode.begin(), token_end, /*count_from_beginning=*/false);
+ *span = CodepointSpan({result_begin_codepoint + begin_ignored,
+ result_end_codepoint - end_ignored});
}
return true;
}
@@ -274,14 +296,28 @@
// Check that the spanned tokens cover the whole span.
bool tokens_match_span;
+ const CodepointIndex tokens_start = tokens[click_position - span_left].start;
+ const CodepointIndex tokens_end = tokens[click_position + span_right].end;
if (options_.snap_label_span_boundaries_to_containing_tokens()) {
- tokens_match_span =
- tokens[click_position - span_left].start <= span.first &&
- tokens[click_position + span_right].end >= span.second;
+ tokens_match_span = tokens_start <= span.first && tokens_end >= span.second;
} else {
- tokens_match_span =
- tokens[click_position - span_left].start == span.first &&
- tokens[click_position + span_right].end == span.second;
+ const UnicodeText token_left_unicode = UTF8ToUnicodeText(
+ tokens[click_position - span_left].value, /*do_copy=*/false);
+ const UnicodeText token_right_unicode = UTF8ToUnicodeText(
+ tokens[click_position + span_right].value, /*do_copy=*/false);
+
+ UnicodeText::const_iterator span_begin = token_left_unicode.begin();
+ UnicodeText::const_iterator span_end = token_right_unicode.end();
+
+ const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints(
+ span_begin, token_left_unicode.end(), /*count_from_beginning=*/true);
+ const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints(
+ token_right_unicode.begin(), span_end, /*count_from_beginning=*/false);
+
+ tokens_match_span = tokens_start <= span.first &&
+ tokens_start + num_punctuation_start >= span.first &&
+ tokens_end >= span.second &&
+ tokens_end - num_punctuation_end <= span.second;
}
if (tokens_match_span) {
@@ -453,6 +489,77 @@
});
}
+void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
+  // Copy the configured codepoints into a set so that
+  // CountIgnoredSpanBoundaryCodepoints() can test membership quickly.
+  const auto& codepoints = options_.ignored_span_boundary_codepoints();
+  ignored_span_boundary_codepoints_.insert(codepoints.begin(),
+                                           codepoints.end());
+}
+
+int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
+    const UnicodeText::const_iterator& span_start,
+    const UnicodeText::const_iterator& span_end,
+    bool count_from_beginning) const {
+  // An empty range has no boundary codepoints to count.
+  if (span_start == span_end) {
+    return 0;
+  }
+
+  // Walk inward from one end of the range: forward from span_start, or
+  // backward from the last codepoint before span_end. The decrements below are
+  // safe because the range is known to hold at least one codepoint.
+  UnicodeText::const_iterator cursor =
+      count_from_beginning ? span_start : span_end;
+  UnicodeText::const_iterator last =
+      count_from_beginning ? span_end : span_start;
+  if (count_from_beginning) {
+    --last;
+  } else {
+    --cursor;
+  }
+
+  // Count consecutive ignored codepoints; stop at the first one that is not
+  // ignored, or once the final codepoint of the range has been counted.
+  int count = 0;
+  while (ignored_span_boundary_codepoints_.count(*cursor) > 0) {
+    ++count;
+    if (cursor == last) {
+      break;
+    }
+    if (count_from_beginning) {
+      ++cursor;
+    } else {
+      --cursor;
+    }
+  }
+  return count;
+}
+
+CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
+    const std::string& context, CodepointSpan span) const {
+  // Empty, reversed, or negative spans have nothing to strip. Returning early
+  // also keeps the iterators below from moving before the start of context.
+  if (span.first < 0 || span.second <= span.first) {
+    return span;
+  }
+
+  const UnicodeText context_unicode =
+      UTF8ToUnicodeText(context, /*do_copy=*/false);
+  const UnicodeText::const_iterator context_end = context_unicode.end();
+
+  // Walk the context once, locating span.first and then span.second, instead
+  // of advancing from begin() twice. A span reaching past the end of the
+  // context is returned unchanged rather than iterating past end().
+  UnicodeText::const_iterator span_begin = context_unicode.begin();
+  int index = 0;
+  for (; index < span.first; ++index) {
+    if (span_begin == context_end) {
+      return span;
+    }
+    ++span_begin;
+  }
+  UnicodeText::const_iterator span_end = span_begin;
+  for (; index < span.second; ++index) {
+    if (span_end == context_end) {
+      return span;
+    }
+    ++span_end;
+  }
+
+  const int start_offset = CountIgnoredSpanBoundaryCodepoints(
+      span_begin, span_end, /*count_from_beginning=*/true);
+  const int end_offset = CountIgnoredSpanBoundaryCodepoints(
+      span_begin, span_end, /*count_from_beginning=*/false);
+
+  // If the whole span consists of boundary codepoints, collapse it to an
+  // empty span at its start.
+  if (span.first + start_offset < span.second - end_offset) {
+    return {span.first + start_offset, span.second - end_offset};
+  } else {
+    return {span.first, span.first};
+  }
+}
+
float FeatureProcessor::SupportedCodepointsRatio(
int click_pos, const std::vector<Token>& tokens) const {
int num_supported = 0;
@@ -614,6 +721,10 @@
}
}
+ if (relative_click_span == std::make_pair(kInvalidIndex, kInvalidIndex)) {
+ relative_click_span = {tokens->size() - 1, tokens->size() - 1};
+ }
+
internal::StripOrPadTokens(relative_click_span, options_.context_size(),
tokens, click_pos);
@@ -621,8 +732,8 @@
const float supported_codepoint_ratio =
SupportedCodepointsRatio(*click_pos, *tokens);
if (supported_codepoint_ratio < options_.min_supported_codepoint_ratio()) {
- TC_LOG(INFO) << "Not enough supported codepoints in the context: "
- << supported_codepoint_ratio;
+ TC_VLOG(1) << "Not enough supported codepoints in the context: "
+ << supported_codepoint_ratio;
return false;
}
}
@@ -658,6 +769,7 @@
bool FeatureProcessor::ICUTokenize(const std::string& context,
std::vector<Token>* result) const {
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
icu::ErrorCode status;
icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(context);
std::unique_ptr<icu::BreakIterator> break_iterator(
@@ -699,6 +811,10 @@
}
return true;
+#else
+ TC_LOG(WARNING) << "Can't tokenize, ICU not supported";
+ return false;
+#endif
}
void FeatureProcessor::InternalRetokenize(const std::string& context,
@@ -758,6 +874,8 @@
// Run the tokenizer and update the token bounds to reflect the offset of the
// substring.
std::vector<Token> tokens = tokenizer_.Tokenize(text);
+ // Avoids progressive capacity increases in the for loop.
+ result->reserve(result->size() + tokens.size());
for (Token& token : tokens) {
token.start += span.first;
token.end += span.first;
diff --git a/smartselect/feature-processor.h b/smartselect/feature-processor.h
index 2c64b67..a39a789 100644
--- a/smartselect/feature-processor.h
+++ b/smartselect/feature-processor.h
@@ -20,6 +20,7 @@
#define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
#include <memory>
+#include <set>
#include <string>
#include <vector>
@@ -104,6 +105,7 @@
{options.internal_tokenizer_codepoint_ranges().begin(),
options.internal_tokenizer_codepoint_ranges().end()},
&internal_tokenizer_codepoint_ranges_);
+ PrepareIgnoredSpanBoundaryCodepoints();
}
explicit FeatureProcessor(const std::string& serialized_options)
@@ -137,6 +139,8 @@
// Extracts features as a CachedFeatures object that can be used for repeated
// inference over token spans in the given context.
+ // When relative_click_span == {kInvalidIndex, kInvalidIndex} then all tokens
+ // extracted from context will be considered.
bool ExtractFeatures(const std::string& context, CodepointSpan input_span,
TokenSpan relative_click_span,
const FeatureVectorFn& feature_vector_fn,
@@ -155,6 +159,12 @@
return feature_extractor_.DenseFeaturesCount();
}
+ // Strips boundary codepoints from the span in context and returns the new
+  // start and end indices. If the span consists entirely of boundary
+ // codepoints, the first index of span is returned for both indices.
+ CodepointSpan StripBoundaryCodepoints(const std::string& context,
+ CodepointSpan span) const;
+
protected:
// Represents a codepoint range [start, end).
struct CodepointRange {
@@ -207,6 +217,18 @@
bool IsCodepointInRanges(
int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;
+ void PrepareIgnoredSpanBoundaryCodepoints();
+
+  // Counts the ignored span boundary codepoints at one end of a span. If
+  // count_from_beginning is true, counting starts at span_start (inclusive)
+  // and moves forward, ending at the latest before span_end (exclusive). If
+  // count_from_beginning is false, counting starts just before span_end and
+  // moves backward, ending at the latest at span_start (inclusive).
+ int CountIgnoredSpanBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_start,
+ const UnicodeText::const_iterator& span_end,
+ bool count_from_beginning) const;
+
// Finds the center token index in tokens vector, using the method defined
// in options_.
int FindCenterToken(CodepointSpan span,
@@ -240,6 +262,10 @@
std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;
private:
+ // Set of codepoints that will be stripped from beginning and end of
+ // predicted spans.
+ std::set<int32> ignored_span_boundary_codepoints_;
+
const FeatureProcessorOptions options_;
// Mapping between token selection spans and labels ids.
diff --git a/tests/feature-processor_test.cc b/smartselect/feature-processor_test.cc
similarity index 77%
rename from tests/feature-processor_test.cc
rename to smartselect/feature-processor_test.cc
index 4e27afc..1a9b9da 100644
--- a/tests/feature-processor_test.cc
+++ b/smartselect/feature-processor_test.cc
@@ -205,6 +205,7 @@
using FeatureProcessor::SupportedCodepointsRatio;
using FeatureProcessor::IsCodepointInRanges;
using FeatureProcessor::ICUTokenize;
+ using FeatureProcessor::CountIgnoredSpanBoundaryCodepoints;
using FeatureProcessor::supported_codepoint_ranges_;
};
@@ -270,6 +271,68 @@
EXPECT_EQ(label2, label3);
}
+// Verifies SpanToLabel on tokens with trailing punctuation ("one, two, three").
+// Without snap_label_span_boundaries_to_containing_tokens, {5, 8} ("two"
+// without its comma) has no label, while the whole token "two," ({5, 9}) does.
+// With snapping enabled, spans inside the token "two," get the same label,
+// spans extending beyond it ({4, 9}, {5, 10}) get no label, and partial spans
+// over "one, two," snap to the label of the full span {6, 15}.
+TEST(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) {
+ FeatureProcessorOptions options;
+ options.set_context_size(1);
+ options.set_max_selection_span(1);
+ options.set_snap_label_span_boundaries_to_containing_tokens(false);
+
+ // Tokenize on the space character (codepoint 32) only.
+ TokenizationCodepointRange* config =
+ options.add_tokenization_codepoint_config();
+ config->set_start(32);
+ config->set_end(33);
+ config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+
+ TestingFeatureProcessor feature_processor(options);
+ std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
+ ASSERT_EQ(3, tokens.size());
+ int label;
+ ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label));
+ EXPECT_EQ(kInvalidLabel, label);
+ ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label));
+ EXPECT_NE(kInvalidLabel, label);
+ TokenSpan token_span;
+ feature_processor.LabelToTokenSpan(label, &token_span);
+ EXPECT_EQ(0, token_span.first);
+ EXPECT_EQ(0, token_span.second);
+
+ // Reconfigure with snapping enabled.
+ options.set_snap_label_span_boundaries_to_containing_tokens(true);
+ TestingFeatureProcessor feature_processor2(options);
+ int label2;
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+
+ // Cross a token boundary.
+ ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2));
+ EXPECT_EQ(kInvalidLabel, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2));
+ EXPECT_EQ(kInvalidLabel, label2);
+
+ // Multiple tokens.
+ options.set_context_size(2);
+ options.set_max_selection_span(2);
+ TestingFeatureProcessor feature_processor3(options);
+ tokens = feature_processor3.Tokenize("zero, one, two, three, four");
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
+ EXPECT_NE(kInvalidLabel, label2);
+ feature_processor3.LabelToTokenSpan(label2, &token_span);
+ EXPECT_EQ(1, token_span.first);
+ EXPECT_EQ(0, token_span.second);
+
+ // Partial spans over "one, two," snap to the same label as {6, 15}.
+ int label3;
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+ ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+}
+
TEST(FeatureProcessorTest, CenterTokenFromClick) {
int token_index;
@@ -610,5 +673,114 @@
// clang-format on
}
+// Checks CountIgnoredSpanBoundaryCodepoints in both directions on plain,
+// multi-byte UTF-8 and empty inputs, and on iterator ranges that start or end
+// inside the string; then checks StripBoundaryCodepoints.
+TEST(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) {
+  FeatureProcessorOptions options;
+  for (const char boundary : {'.', ',', '[', ']'}) {
+    options.add_ignored_span_boundary_codepoints(boundary);
+  }
+
+  TestingFeatureProcessor feature_processor(options);
+
+  // Counts ignored boundary codepoints within codepoints [begin, end) of
+  // `utf8`; a negative `end` stands for the end of the string.
+  const auto count = [&feature_processor](const std::string& utf8, int begin,
+                                          int end, bool count_from_beginning) {
+    const UnicodeText text = UTF8ToUnicodeText(utf8, /*do_copy=*/false);
+    UnicodeText::const_iterator span_start = text.begin();
+    std::advance(span_start, begin);
+    UnicodeText::const_iterator span_end = text.end();
+    if (end >= 0) {
+      span_end = text.begin();
+      std::advance(span_end, end);
+    }
+    return feature_processor.CountIgnoredSpanBoundaryCodepoints(
+        span_start, span_end, count_from_beginning);
+  };
+
+  EXPECT_EQ(count("ěščř", 0, -1, /*count_from_beginning=*/true), 0);
+  EXPECT_EQ(count("ěščř", 0, -1, /*count_from_beginning=*/false), 0);
+
+  EXPECT_EQ(count(".,abčd", 0, -1, /*count_from_beginning=*/true), 2);
+  EXPECT_EQ(count(".,abčd", 0, -1, /*count_from_beginning=*/false), 0);
+
+  EXPECT_EQ(count(".,abčd[]", 0, -1, /*count_from_beginning=*/true), 2);
+  EXPECT_EQ(count(".,abčd[]", 0, -1, /*count_from_beginning=*/false), 2);
+
+  EXPECT_EQ(count("[abčd]", 0, -1, /*count_from_beginning=*/true), 1);
+  EXPECT_EQ(count("[abčd]", 0, -1, /*count_from_beginning=*/false), 1);
+
+  EXPECT_EQ(count("", 0, -1, /*count_from_beginning=*/true), 0);
+  EXPECT_EQ(count("", 0, -1, /*count_from_beginning=*/false), 0);
+
+  // Ranges that begin after the first codepoint.
+  EXPECT_EQ(count("012345ěščř", 6, -1, /*count_from_beginning=*/true), 0);
+  EXPECT_EQ(count("012345ěščř", 6, -1, /*count_from_beginning=*/false), 0);
+  EXPECT_EQ(count("012345.,ěščř", 6, -1, /*count_from_beginning=*/true), 2);
+
+  // Range that ends before the last codepoint.
+  EXPECT_EQ(count("012345.,ěščř", 0, 8, /*count_from_beginning=*/false), 2);
+
+  // StripBoundaryCodepoints: nothing to strip, basic stripping, everything
+  // stripped, and an empty string.
+  EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
+                "Hello [[[Wořld]] or not?", {0, 24}),
+            std::make_pair(0, 24));
+  EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
+                "Hello [[[Wořld]] or not?", {6, 16}),
+            std::make_pair(9, 14));
+  EXPECT_EQ(
+      feature_processor.StripBoundaryCodepoints("Hello [[[]] or not?", {6, 11}),
+      std::make_pair(6, 6));
+  EXPECT_EQ(feature_processor.StripBoundaryCodepoints("", {0, 0}),
+            std::make_pair(0, 0));
+}
+
} // namespace
} // namespace libtextclassifier
diff --git a/smartselect/model-params.cc b/smartselect/model-params.cc
index 9a31bab..65c4f93 100644
--- a/smartselect/model-params.cc
+++ b/smartselect/model-params.cc
@@ -43,6 +43,7 @@
reader.trimmed_proto().GetExtension(feature_processor_extension_id);
// If no tokenization codepoint config is present, tokenize on space.
+ // TODO(zilka): Remove the default config.
if (feature_processor_options.tokenization_codepoint_config_size() == 0) {
TokenizationCodepointRange* config;
// New line character.
@@ -67,42 +68,16 @@
if (reader.trimmed_proto().HasExtension(selection_options_extension_id)) {
selection_options =
reader.trimmed_proto().GetExtension(selection_options_extension_id);
- } else {
- // Default values when SelectionModelOptions is not present.
- for (const auto codepoint_pair : std::vector<std::pair<int, int>>(
- {{33, 35}, {37, 39}, {42, 42}, {44, 47},
- {58, 59}, {63, 64}, {91, 93}, {95, 95},
- {123, 123}, {125, 125}, {161, 161}, {171, 171},
- {183, 183}, {187, 187}, {191, 191}, {894, 894},
- {903, 903}, {1370, 1375}, {1417, 1418}, {1470, 1470},
- {1472, 1472}, {1475, 1475}, {1478, 1478}, {1523, 1524},
- {1548, 1549}, {1563, 1563}, {1566, 1567}, {1642, 1645},
- {1748, 1748}, {1792, 1805}, {2404, 2405}, {2416, 2416},
- {3572, 3572}, {3663, 3663}, {3674, 3675}, {3844, 3858},
- {3898, 3901}, {3973, 3973}, {4048, 4049}, {4170, 4175},
- {4347, 4347}, {4961, 4968}, {5741, 5742}, {5787, 5788},
- {5867, 5869}, {5941, 5942}, {6100, 6102}, {6104, 6106},
- {6144, 6154}, {6468, 6469}, {6622, 6623}, {6686, 6687},
- {8208, 8231}, {8240, 8259}, {8261, 8273}, {8275, 8286},
- {8317, 8318}, {8333, 8334}, {9001, 9002}, {9140, 9142},
- {10088, 10101}, {10181, 10182}, {10214, 10219}, {10627, 10648},
- {10712, 10715}, {10748, 10749}, {11513, 11516}, {11518, 11519},
- {11776, 11799}, {11804, 11805}, {12289, 12291}, {12296, 12305},
- {12308, 12319}, {12336, 12336}, {12349, 12349}, {12448, 12448},
- {12539, 12539}, {64830, 64831}, {65040, 65049}, {65072, 65106},
- {65108, 65121}, {65123, 65123}, {65128, 65128}, {65130, 65131},
- {65281, 65283}, {65285, 65290}, {65292, 65295}, {65306, 65307},
- {65311, 65312}, {65339, 65341}, {65343, 65343}, {65371, 65371},
- {65373, 65373}, {65375, 65381}, {65792, 65793}, {66463, 66463},
- {68176, 68184}})) {
- for (int i = codepoint_pair.first; i <= codepoint_pair.second; i++) {
- selection_options.add_punctuation_to_strip(i);
- }
- selection_options.set_strip_punctuation(true);
- selection_options.set_enforce_symmetry(true);
- selection_options.set_symmetry_context_size(
- feature_processor_options.context_size() * 2);
+
+ // For backward compatibility with the current models.
+ if (!feature_processor_options.ignored_span_boundary_codepoints_size()) {
+ *feature_processor_options.mutable_ignored_span_boundary_codepoints() =
+ selection_options.deprecated_punctuation_to_strip();
}
+ } else {
+ selection_options.set_enforce_symmetry(true);
+ selection_options.set_symmetry_context_size(
+ feature_processor_options.context_size() * 2);
}
SharingModelOptions sharing_options;
diff --git a/smartselect/model-parser.cc b/smartselect/model-parser.cc
new file mode 100644
index 0000000..0cf05e3
--- /dev/null
+++ b/smartselect/model-parser.cc
@@ -0,0 +1,91 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/model-parser.h"
+#include "util/base/endian.h"
+
+namespace libtextclassifier {
+namespace {
+
+// Small helper class for parsing the merged model format.
+// The merged model consists of interleaved <int32 data_size, char* data>
+// segments, where data_size is stored as a little-endian 32-bit integer.
+// Every read is bounds-checked against [addr, addr + size): a read that would
+// run past the end of the buffer, or that asks for a negative number of bytes,
+// fails without moving the read position.
+class MergedModelParser {
+ public:
+  MergedModelParser(const void* addr, const int size)
+      : addr_(reinterpret_cast<const char*>(addr)), size_(size), pos_(addr_) {}
+
+  // Sets *result to point at the next num_bytes bytes (no copy is made) and
+  // advances past them. Returns false, leaving *result untouched, if
+  // num_bytes is negative or exceeds the remaining bytes.
+  bool ReadBytesAndAdvance(int num_bytes, const char** result) {
+    const char* read_addr = pos_;
+    if (!Advance(num_bytes)) {
+      return false;
+    }
+    *result = read_addr;
+    return true;
+  }
+
+  // Reads a little-endian 32-bit integer and advances past it. Returns false,
+  // leaving *result untouched, if fewer than 4 bytes remain.
+  bool ReadInt32AndAdvance(int* result) {
+    const char* read_addr = pos_;
+    if (!Advance(sizeof(uint32))) {
+      return false;
+    }
+    // Assemble the value byte by byte: a length field that follows a
+    // variable-length segment is generally unaligned, so dereferencing it as
+    // a uint32* would be undefined behavior.
+    const unsigned char* bytes =
+        reinterpret_cast<const unsigned char*>(read_addr);
+    const uint32 value = static_cast<uint32>(bytes[0]) |
+                         (static_cast<uint32>(bytes[1]) << 8) |
+                         (static_cast<uint32>(bytes[2]) << 16) |
+                         (static_cast<uint32>(bytes[3]) << 24);
+    *result = static_cast<int>(value);
+    return true;
+  }
+
+  // Returns true if the whole buffer has been consumed.
+  bool IsDone() const { return pos_ == addr_ + size_; }
+
+ private:
+  // Advances the read position by num_bytes if the result stays inside the
+  // buffer. Checking before moving pos_ avoids ever forming a pointer past the
+  // end of the buffer, and rejecting negative counts prevents moving backwards.
+  bool Advance(int num_bytes) {
+    if (num_bytes < 0 || num_bytes > (addr_ + size_) - pos_) {
+      return false;
+    }
+    pos_ += num_bytes;
+    return true;
+  }
+
+  const char* addr_;
+  const int size_;
+  const char* pos_;
+};
+
+} // namespace
+
+// Splits the merged model image of `size` bytes at `addr` into its selection
+// and sharing models. The image must consist of exactly two
+// <int32 length, bytes> segments, selection model first; a truncated image or
+// trailing bytes make this return false. On success the output pointers point
+// into the image itself.
+bool ParseMergedModel(const void* addr, const int size,
+                      const char** selection_model, int* selection_model_length,
+                      const char** sharing_model, int* sharing_model_length) {
+  MergedModelParser parser(addr, size);
+  // && short-circuits, so parsing stops at the first read that fails.
+  const bool segments_read =
+      parser.ReadInt32AndAdvance(selection_model_length) &&
+      parser.ReadBytesAndAdvance(*selection_model_length, selection_model) &&
+      parser.ReadInt32AndAdvance(sharing_model_length) &&
+      parser.ReadBytesAndAdvance(*sharing_model_length, sharing_model);
+  return segments_read && parser.IsDone();
+}
+
+} // namespace libtextclassifier
diff --git a/tests/functions.cc b/smartselect/model-parser.h
similarity index 62%
copy from tests/functions.cc
copy to smartselect/model-parser.h
index 8ea5a8d..801262f 100644
--- a/tests/functions.cc
+++ b/smartselect/model-parser.h
@@ -14,16 +14,16 @@
* limitations under the License.
*/
-#include "tests/functions.h"
-
-#include "common/registry.h"
+#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARSER_H_
+#define LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARSER_H_
namespace libtextclassifier {
-namespace nlp_core {
-TC_DEFINE_CLASS_REGISTRY_NAME("function", functions::Function);
+// Parses a merged model image of `size` bytes at `addr`, laid out as
+//   <int32 selection_model_length><selection model bytes>
+//   <int32 sharing_model_length><sharing model bytes>
+// with the lengths stored little-endian. On success, sets *selection_model
+// and *sharing_model to point into the image (no copy is made), sets their
+// lengths, and returns true. Returns false if the image is truncated or has
+// trailing bytes after the sharing model.
+bool ParseMergedModel(const void* addr, const int size,
+ const char** selection_model, int* selection_model_length,
+ const char** sharing_model, int* sharing_model_length);
-TC_DEFINE_CLASS_REGISTRY_NAME("int-function", functions::IntFunction);
-
-} // namespace nlp_core
} // namespace libtextclassifier
+
+#endif // LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARSER_H_
diff --git a/smartselect/text-classification-model.cc b/smartselect/text-classification-model.cc
index dee8f8b..3e5068d 100644
--- a/smartselect/text-classification-model.cc
+++ b/smartselect/text-classification-model.cc
@@ -16,6 +16,7 @@
#include "smartselect/text-classification-model.h"
+#include <cctype>
#include <cmath>
#include <iterator>
#include <numeric>
@@ -26,10 +27,14 @@
#include "common/memory_image/memory-image-reader.h"
#include "common/mmap.h"
#include "common/softmax.h"
+#include "smartselect/model-parser.h"
#include "smartselect/text-classification-model.pb.h"
#include "util/base/logging.h"
#include "util/utf8/unicodetext.h"
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+#include "unicode/regex.h"
#include "unicode/uchar.h"
+#endif
namespace libtextclassifier {
@@ -49,65 +54,60 @@
const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
if (i >= selection_indices.first && i < selection_indices.second &&
- u_isdigit(*it)) {
+ isdigit(*it)) {
++count;
}
}
return count;
}
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+// Returns true if `regex` matches the entire `context` string (decoded from
+// UTF-8). RegexMatcher::matches() requires the match to span the whole input
+// from the start index to the end. Returns false if ICU fails to create the
+// matcher or reports an error while matching.
+bool MatchesRegex(const icu::RegexPattern* regex, const std::string& context) {
+  const icu::UnicodeString unicode_context(context.c_str(), context.size(),
+                                           "utf-8");
+  UErrorCode status = U_ZERO_ERROR;
+  std::unique_ptr<icu::RegexMatcher> matcher(
+      regex->matcher(unicode_context, status));
+  // matcher() reports failures through `status` and may return null; either
+  // way there is nothing to match against.
+  if (U_FAILURE(status) || matcher == nullptr) {
+    return false;
+  }
+  const bool matched = matcher->matches(0 /* start */, status);
+  return U_SUCCESS(status) && matched;
+}
+#endif
+
} // namespace
-CodepointSpan TextClassificationModel::StripPunctuation(
- CodepointSpan selection, const std::string& context) const {
- UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false);
- int context_length =
- std::distance(context_unicode.begin(), context_unicode.end());
-
- // Check that the indices are valid.
- if (selection.first < 0 || selection.first > context_length ||
- selection.second < 0 || selection.second > context_length) {
- return selection;
- }
-
- // Move the left border until we encounter a non-punctuation character.
- UnicodeText::const_iterator it_from_begin = context_unicode.begin();
- std::advance(it_from_begin, selection.first);
- for (; punctuation_to_strip_.find(*it_from_begin) !=
- punctuation_to_strip_.end();
- ++it_from_begin, ++selection.first) {
- }
-
- // Unless we are already at the end, move the right border until we encounter
- // a non-punctuation character.
- UnicodeText::const_iterator it_from_end = context_unicode.begin();
- std::advance(it_from_end, selection.second);
- if (it_from_begin != it_from_end) {
- --it_from_end;
- for (; punctuation_to_strip_.find(*it_from_end) !=
- punctuation_to_strip_.end();
- --it_from_end, --selection.second) {
- }
- return selection;
- } else {
- // When the token is all punctuation.
- return {0, 0};
- }
+// Creates a ScopedMmap of the file at `path` and loads the models from it via
+// InitFromMmap().
+TextClassificationModel::TextClassificationModel(const std::string& path)
+ : mmap_(new nlp_core::ScopedMmap(path)) {
+ InitFromMmap();
}
-TextClassificationModel::TextClassificationModel(int fd) : mmap_(fd) {
- initialized_ = LoadModels(mmap_.handle());
+// Creates a ScopedMmap of the whole file referred to by `fd` and loads the
+// models from it via InitFromMmap().
+TextClassificationModel::TextClassificationModel(int fd)
+ : mmap_(new nlp_core::ScopedMmap(fd)) {
+ InitFromMmap();
+}
+
+// Creates a ScopedMmap of `size` bytes of `fd` starting at byte `offset` and
+// loads the models from it via InitFromMmap().
+TextClassificationModel::TextClassificationModel(int fd, int offset, int size)
+ : mmap_(new nlp_core::ScopedMmap(fd, offset, size)) {
+ InitFromMmap();
+}
+
+// Loads the models directly from `size` bytes at `addr`; no mapping is created
+// here. NOTE(review): presumably `addr` must outlive this object — confirm.
+TextClassificationModel::TextClassificationModel(const void* addr, int size) {
+ initialized_ = LoadModels(addr, size);
if (!initialized_) {
TC_LOG(ERROR) << "Failed to load models";
return;
}
+}
- selection_options_ = selection_params_->GetSelectionModelOptions();
- for (const int codepoint : selection_options_.punctuation_to_strip()) {
- punctuation_to_strip_.insert(codepoint);
+// Loads the models from the mapped region held by mmap_. Returns early, without
+// assigning initialized_, if the mapping is not ok.
+void TextClassificationModel::InitFromMmap() {
+ if (!mmap_->handle().ok()) {
+ return;
}
- sharing_options_ = selection_params_->GetSharingModelOptions();
+ initialized_ =
+ LoadModels(mmap_->handle().start(), mmap_->handle().num_bytes());
+ if (!initialized_) {
+ TC_LOG(ERROR) << "Failed to load models";
+ return;
+ }
}
namespace {
@@ -151,40 +151,23 @@
};
}
-void ParseMergedModel(const MmapHandle& mmap_handle,
- const char** selection_model, int* selection_model_length,
- const char** sharing_model, int* sharing_model_length) {
- // Read the length of the selection model.
- const char* model_data = reinterpret_cast<const char*>(mmap_handle.start());
- *selection_model_length =
- LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
- model_data += sizeof(*selection_model_length);
- *selection_model = model_data;
- model_data += *selection_model_length;
-
- *sharing_model_length =
- LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
- model_data += sizeof(*sharing_model_length);
- *sharing_model = model_data;
-}
-
} // namespace
-bool TextClassificationModel::LoadModels(const MmapHandle& mmap_handle) {
- if (!mmap_handle.ok()) {
- return false;
- }
-
+bool TextClassificationModel::LoadModels(const void* addr, int size) {
const char *selection_model, *sharing_model;
int selection_model_length, sharing_model_length;
- ParseMergedModel(mmap_handle, &selection_model, &selection_model_length,
- &sharing_model, &sharing_model_length);
+ if (!ParseMergedModel(addr, size, &selection_model, &selection_model_length,
+ &sharing_model, &sharing_model_length)) {
+ TC_LOG(ERROR) << "Couldn't parse the model.";
+ return false;
+ }
selection_params_.reset(
ModelParamsBuilder(selection_model, selection_model_length, nullptr));
if (!selection_params_.get()) {
return false;
}
+ selection_options_ = selection_params_->GetSelectionModelOptions();
selection_network_.reset(new EmbeddingNetwork(selection_params_.get()));
selection_feature_processor_.reset(
new FeatureProcessor(selection_params_->GetFeatureProcessorOptions()));
@@ -197,12 +180,35 @@
if (!sharing_params_.get()) {
return false;
}
+ sharing_options_ = selection_params_->GetSharingModelOptions();
sharing_network_.reset(new EmbeddingNetwork(sharing_params_.get()));
sharing_feature_processor_.reset(
new FeatureProcessor(sharing_params_->GetFeatureProcessorOptions()));
sharing_feature_fn_ = CreateFeatureVectorFn(
*sharing_network_, sharing_network_->EmbeddingSize(0));
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+ // Initialize pattern recognizers.
+ for (const auto& regex_pattern : sharing_options_.regex_pattern()) {
+ UErrorCode status = U_ZERO_ERROR;
+ std::unique_ptr<icu::RegexPattern> compiled_pattern(
+ icu::RegexPattern::compile(
+ icu::UnicodeString(regex_pattern.pattern().c_str(),
+ regex_pattern.pattern().size(), "utf-8"),
+ 0 /* flags */, status));
+ if (U_FAILURE(status)) {
+ TC_LOG(WARNING) << "Failed to load pattern" << regex_pattern.pattern();
+ } else {
+ regex_patterns_.push_back(
+ {regex_pattern.collection_name(), std::move(compiled_pattern)});
+ }
+ }
+#else
+ if (sharing_options_.regex_pattern_size() > 0) {
+ TC_LOG(WARNING) << "ICU not supported regexp matchers ignored.";
+ }
+#endif
+
return true;
}
@@ -215,8 +221,12 @@
const char *selection_model, *sharing_model;
int selection_model_length, sharing_model_length;
- ParseMergedModel(mmap.handle(), &selection_model, &selection_model_length,
- &sharing_model, &sharing_model_length);
+ if (!ParseMergedModel(mmap.handle().start(), mmap.handle().num_bytes(),
+ &selection_model, &selection_model_length,
+ &sharing_model, &sharing_model_length)) {
+ TC_LOG(ERROR) << "Couldn't parse merged model.";
+ return false;
+ }
MemoryImageReader<EmbeddingNetworkProto> reader(selection_model,
selection_model_length);
@@ -245,14 +255,14 @@
CreateFeatureVectorFn(network, embedding_size),
embedding_size + feature_processor.DenseFeaturesCount(), &tokens,
&click_pos, &cached_features)) {
- TC_LOG(ERROR) << "Could not extract features.";
+ TC_VLOG(1) << "Could not extract features.";
return {};
}
VectorSpan<float> features;
VectorSpan<Token> output_tokens;
if (!cached_features->Get(click_pos, &features, &output_tokens)) {
- TC_LOG(ERROR) << "Could not extract features.";
+ TC_VLOG(1) << "Could not extract features.";
return {};
}
@@ -277,9 +287,9 @@
}
if (std::get<0>(click_indices) >= std::get<1>(click_indices)) {
- TC_LOG(ERROR) << "Trying to run SuggestSelection with invalid indices:"
- << std::get<0>(click_indices) << " "
- << std::get<1>(click_indices);
+ TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices:"
+ << std::get<0>(click_indices) << " "
+ << std::get<1>(click_indices);
return click_indices;
}
@@ -300,28 +310,32 @@
std::tie(result, score) = SuggestSelectionInternal(context, click_indices);
}
- if (selection_options_.strip_punctuation()) {
- result = StripPunctuation(result, context);
- }
-
return result;
}
namespace {
-std::pair<CodepointSpan, float> BestSelectionSpan(
- CodepointSpan original_click_indices, const std::vector<float>& scores,
- const std::vector<CodepointSpan>& selection_label_spans) {
+// Returns the index of the (first) highest score in `scores`, or kInvalidLabel
+// if `scores` is empty.
+int BestPrediction(const std::vector<float>& scores) {
if (!scores.empty()) {
const int prediction =
std::max_element(scores.begin(), scores.end()) - scores.begin();
+ return prediction;
+ } else {
+ return kInvalidLabel;
+ }
+}
+
+std::pair<CodepointSpan, float> BestSelectionSpan(
+ CodepointSpan original_click_indices, const std::vector<float>& scores,
+ const std::vector<CodepointSpan>& selection_label_spans) {
+ const int prediction = BestPrediction(scores);
+ if (prediction != kInvalidLabel) {
std::pair<CodepointIndex, CodepointIndex> selection =
selection_label_spans[prediction];
if (selection.first == kInvalidIndex || selection.second == kInvalidIndex) {
- TC_LOG(ERROR) << "Invalid indices predicted, returning input: "
- << prediction << " " << selection.first << " "
- << selection.second;
+ TC_VLOG(1) << "Invalid indices predicted, returning input: " << prediction
+ << " " << selection.first << " " << selection.second;
return {original_click_indices, -1.0};
}
@@ -367,86 +381,17 @@
+// Suggests a selection by chunking the neighborhood of the click (up to
+// symmetry_context_size tokens on either side) with Chunk() and returning the
+// first chunk, in start order, that overlaps click_indices. Falls back to
+// click_indices when no chunk overlaps, including when Chunk() returns none.
CodepointSpan TextClassificationModel::SuggestSelectionSymmetrical(
const std::string& context, CodepointSpan click_indices) const {
const int symmetry_context_size = selection_options_.symmetry_context_size();
- std::vector<Token> tokens;
- std::unique_ptr<CachedFeatures> cached_features;
- int click_index;
- int embedding_size = selection_network_->EmbeddingSize(0);
- if (!selection_feature_processor_->ExtractFeatures(
- context, click_indices, /*relative_click_span=*/
- {symmetry_context_size, symmetry_context_size + 1},
- selection_feature_fn_,
- embedding_size + selection_feature_processor_->DenseFeaturesCount(),
- &tokens, &click_index, &cached_features)) {
- TC_LOG(ERROR) << "Couldn't ExtractFeatures.";
- return click_indices;
- }
-
- // Scan in the symmetry context for selection span proposals.
- std::vector<std::pair<CodepointSpan, float>> proposals;
-
- for (int i = -symmetry_context_size; i < symmetry_context_size + 1; ++i) {
- const int token_index = click_index + i;
- if (token_index >= 0 && token_index < tokens.size() &&
- !tokens[token_index].is_padding) {
- float score;
- VectorSpan<float> features;
- VectorSpan<Token> output_tokens;
-
- CodepointSpan span;
- if (cached_features->Get(token_index, &features, &output_tokens)) {
- std::vector<float> scores;
- selection_network_->ComputeLogits(features, &scores);
-
- std::vector<CodepointSpan> selection_label_spans;
- if (selection_feature_processor_->SelectionLabelSpans(
- output_tokens, &selection_label_spans)) {
- scores = nlp_core::ComputeSoftmax(scores);
- std::tie(span, score) =
- BestSelectionSpan(click_indices, scores, selection_label_spans);
- if (span.first != kInvalidIndex && span.second != kInvalidIndex &&
- score >= 0) {
- proposals.push_back({span, score});
- }
- }
- }
+ // Chunks are returned sorted by their start index.
+ std::vector<CodepointSpan> chunks = Chunk(
+ context, click_indices, {symmetry_context_size, symmetry_context_size});
+ for (const CodepointSpan& chunk : chunks) {
+ // If chunk and click indices have an overlap, return the chunk.
+ if (!(click_indices.first >= chunk.second ||
+ click_indices.second <= chunk.first)) {
+ return chunk;
}
}
- // Sort selection span proposals by their respective probabilities.
- std::sort(
- proposals.begin(), proposals.end(),
- [](std::pair<CodepointSpan, float> a, std::pair<CodepointSpan, float> b) {
- return a.second > b.second;
- });
-
- // Go from the highest-scoring proposal and claim tokens. Tokens are marked as
- // claimed by the higher-scoring selection proposals, so that the
- // lower-scoring ones cannot use them. Returns the selection proposal if it
- // contains the clicked token.
- std::vector<int> used_tokens(tokens.size(), 0);
- for (auto span_result : proposals) {
- TokenSpan span = CodepointSpanToTokenSpan(tokens, span_result.first);
- if (span.first != kInvalidIndex && span.second != kInvalidIndex) {
- bool feasible = true;
- for (int i = span.first; i < span.second; i++) {
- if (used_tokens[i] != 0) {
- feasible = false;
- break;
- }
- }
-
- if (feasible) {
- if (span.first <= click_index && span.second > click_index) {
- return {span_result.first.first, span_result.first.second};
- }
- for (int i = span.first; i < span.second; i++) {
- used_tokens[i] = 1;
- }
- }
- }
- }
-
- return {click_indices.first, click_indices.second};
+ return click_indices;
}
std::vector<std::pair<std::string, float>>
@@ -459,9 +404,9 @@
}
if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
- TC_LOG(ERROR) << "Trying to run ClassifyText with invalid indices: "
- << std::get<0>(selection_indices) << " "
- << std::get<1>(selection_indices);
+ TC_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
+ << std::get<0>(selection_indices) << " "
+ << std::get<1>(selection_indices);
return {};
}
@@ -475,21 +420,29 @@
return {{kEmailHintCollection, 1.0}};
}
+ // Check whether any of the regular expressions match.
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+ for (const CompiledRegexPattern& regex_pattern : regex_patterns_) {
+ if (MatchesRegex(regex_pattern.pattern.get(), context)) {
+ return {{regex_pattern.collection_name, 1.0}};
+ }
+ }
+#endif
+
EmbeddingNetwork::Vector scores =
InferInternal(context, selection_indices, *sharing_feature_processor_,
*sharing_network_, sharing_feature_fn_, nullptr);
if (scores.empty() ||
scores.size() != sharing_feature_processor_->NumCollections()) {
- TC_LOG(ERROR) << "Using default class: scores.size() = " << scores.size();
+ TC_VLOG(1) << "Using default class: scores.size() = " << scores.size();
return {};
}
scores = nlp_core::ComputeSoftmax(scores);
- std::vector<std::pair<std::string, float>> result;
+ std::vector<std::pair<std::string, float>> result(scores.size());
for (int i = 0; i < scores.size(); i++) {
- result.push_back(
- {sharing_feature_processor_->LabelToCollection(i), scores[i]});
+ result[i] = {sharing_feature_processor_->LabelToCollection(i), scores[i]};
}
std::sort(result.begin(), result.end(),
[](const std::pair<std::string, float>& a,
@@ -509,4 +462,147 @@
return result;
}
+// Splits the context around click_span into non-overlapping chunks (codepoint
+// spans) using the selection model.
+//
+// Every non-padding token within relative_click_span of the clicked token
+// contributes an implicit single-token proposal (score 0). When the network
+// predicts a valid span, the token also contributes that span with its softmax
+// score. Proposals are accepted greedily from the highest score down; a
+// proposal is dropped if any of its tokens was already claimed. When
+// relative_click_span is {kInvalidIndex, kInvalidIndex}, all extracted tokens
+// are considered.
+// Returns the accepted spans sorted by start index, or an empty vector if
+// feature extraction fails.
+std::vector<CodepointSpan> TextClassificationModel::Chunk(
+    const std::string& context, CodepointSpan click_span,
+    TokenSpan relative_click_span) const {
+  std::unique_ptr<CachedFeatures> cached_features;
+  std::vector<Token> tokens;
+  int click_index = 0;
+
+  const int embedding_size = selection_network_->EmbeddingSize(0);
+  // TODO(zilka): Refactor the ExtractFeatures API to smoothly support the
+  // different use cases. Currently it is very click-centric.
+  if (!selection_feature_processor_->ExtractFeatures(
+          context, click_span, relative_click_span, selection_feature_fn_,
+          embedding_size + selection_feature_processor_->DenseFeaturesCount(),
+          &tokens, &click_index, &cached_features)) {
+    TC_VLOG(1) << "Couldn't ExtractFeatures.";
+    return {};
+  }
+
+  const int num_tokens = static_cast<int>(tokens.size());
+  if (relative_click_span == std::make_pair(kInvalidIndex, kInvalidIndex)) {
+    // A radius of num_tokens - 1 around the click covers every token. Doing
+    // the arithmetic in int avoids wrapping size_t when tokens is empty.
+    relative_click_span = {num_tokens - 1, num_tokens - 1};
+  }
+
+  struct SelectionProposal {
+    int label;
+    int token_index;
+    CodepointSpan span;
+    float score;
+  };
+
+  // Scan the tokens within relative_click_span of the click for selection
+  // span proposals.
+  std::vector<SelectionProposal> proposals;
+  for (int i = -relative_click_span.first; i < relative_click_span.second + 1;
+       ++i) {
+    const int token_index = click_index + i;
+    if (token_index < 0 || token_index >= num_tokens ||
+        tokens[token_index].is_padding) {
+      continue;
+    }
+
+    VectorSpan<float> features;
+    VectorSpan<Token> output_tokens;
+    std::vector<CodepointSpan> selection_label_spans;
+    if (!cached_features->Get(token_index, &features, &output_tokens) ||
+        !selection_feature_processor_->SelectionLabelSpans(
+            output_tokens, &selection_label_spans) ||
+        selection_label_spans.empty()) {
+      // No usable model output for this token: add an implicit proposal for
+      // the token to be by itself so that every token is represented in the
+      // results.
+      proposals.push_back(SelectionProposal{
+          0,
+          token_index,
+          {tokens[token_index].start, tokens[token_index].end},
+          0.0f});
+      continue;
+    }
+
+    // Add an implicit proposal (label 0) for the token to be by itself so that
+    // every token is represented in the results.
+    proposals.push_back(
+        SelectionProposal{0, token_index, selection_label_spans[0], 0.0f});
+
+    std::vector<float> scores;
+    selection_network_->ComputeLogits(features, &scores);
+    scores = nlp_core::ComputeSoftmax(scores);
+
+    CodepointSpan span;
+    float score;
+    std::tie(span, score) = BestSelectionSpan(
+        {kInvalidIndex, kInvalidIndex}, scores, selection_label_spans);
+    if (span.first != kInvalidIndex && span.second != kInvalidIndex &&
+        score >= 0) {
+      proposals.push_back(
+          SelectionProposal{BestPrediction(scores), token_index, span, score});
+    }
+  }
+
+  // Sort selection span proposals by their respective probabilities.
+  std::sort(proposals.begin(), proposals.end(),
+            [](const SelectionProposal& a, const SelectionProposal& b) {
+              return a.score > b.score;
+            });
+
+  // Go from the highest-scoring proposal and claim tokens. Tokens claimed by a
+  // higher-scoring proposal cannot be used by lower-scoring ones; every
+  // proposal whose tokens are all still unclaimed is added to the result.
+  std::vector<CodepointSpan> result;
+  std::vector<bool> token_used(tokens.size(), false);
+  for (const SelectionProposal& proposal : proposals) {
+    TokenSpan relative_span;
+    if (!selection_feature_processor_->LabelToTokenSpan(proposal.label,
+                                                        &relative_span)) {
+      continue;
+    }
+    const int span_begin = proposal.token_index - relative_span.first;
+    const int span_end = proposal.token_index + relative_span.second + 1;
+    // The token span is derived arithmetically from the label, so it can fall
+    // outside the token vector; skip it rather than index out of bounds.
+    if (span_begin < 0 || span_end > num_tokens) {
+      continue;
+    }
+
+    bool feasible = true;
+    for (int i = span_begin; i < span_end; ++i) {
+      if (token_used[i]) {
+        feasible = false;
+        break;
+      }
+    }
+    if (!feasible) {
+      continue;
+    }
+
+    result.push_back(proposal.span);
+    for (int i = span_begin; i < span_end; ++i) {
+      token_used[i] = true;
+    }
+  }
+
+  std::sort(result.begin(), result.end(),
+            [](const CodepointSpan& a, const CodepointSpan& b) {
+              return a.first < b.first;
+            });
+
+  return result;
+}
+
+// Chunks the whole context and classifies each chunk. Passing
+// {kInvalidIndex, kInvalidIndex} as the relative click span makes Chunk()
+// consider every extracted token; the click span {0, 1} only seeds feature
+// extraction. The results follow Chunk()'s ordering (by start index).
+std::vector<TextClassificationModel::AnnotatedSpan>
+TextClassificationModel::Annotate(const std::string& context) const {
+  const std::vector<CodepointSpan> chunks =
+      Chunk(context, /*click_span=*/{0, 1},
+            /*relative_click_span=*/{kInvalidIndex, kInvalidIndex});
+
+  std::vector<TextClassificationModel::AnnotatedSpan> annotations;
+  annotations.reserve(chunks.size());
+  for (const CodepointSpan& chunk : chunks) {
+    AnnotatedSpan annotation;
+    annotation.span = chunk;
+    annotation.classification = ClassifyText(context, chunk);
+    annotations.push_back(std::move(annotation));
+  }
+  return annotations;
+}
+
} // namespace libtextclassifier
diff --git a/smartselect/text-classification-model.h b/smartselect/text-classification-model.h
index 522372c..5b58d89 100644
--- a/smartselect/text-classification-model.h
+++ b/smartselect/text-classification-model.h
@@ -23,7 +23,6 @@
#include <set>
#include <string>
-#include "base.h"
#include "common/embedding-network.h"
#include "common/feature-extractor.h"
#include "common/memory_image/embedding-network-params-from-image.h"
@@ -38,10 +37,33 @@
// SmartSelection/Sharing feed-forward model.
class TextClassificationModel {
public:
+ // Represents a result of Annotate call.
+ struct AnnotatedSpan {
+ // Unicode codepoint indices in the input string.
+ CodepointSpan span = {kInvalidIndex, kInvalidIndex};
+
+ // Classification result for the span.
+ std::vector<std::pair<std::string, float>> classification;
+ };
+
// Loads TextClassificationModel from given file given by an int
// file descriptor.
+  // Offset is the byte position in the file where the model data begins.
+ TextClassificationModel(int fd, int offset, int size);
+
+ // Same as above but the whole file is mapped and it is assumed the model
+ // starts at offset 0.
explicit TextClassificationModel(int fd);
+ // Loads TextClassificationModel from given file.
+ explicit TextClassificationModel(const std::string& path);
+
+ // Loads TextClassificationModel from given location in memory.
+ TextClassificationModel(const void* addr, int size);
+
+ // Returns true if the model is ready for use.
+ bool IsInitialized() { return initialized_; }
+
// Bit flags for the input selection.
enum SelectionInputFlags { SELECTION_IS_URL = 0x1, SELECTION_IS_EMAIL = 0x2 };
@@ -63,11 +85,26 @@
const std::string& context, CodepointSpan click_indices,
int input_flags = 0) const;
+ // Annotates given input text. The annotations should cover the whole input
+ // context except for whitespaces, and are sorted by their position in the
+ // context string.
+ std::vector<AnnotatedSpan> Annotate(const std::string& context) const;
+
protected:
- // Removes punctuation from the beginning and end of the selection and returns
- // the new selection span.
- CodepointSpan StripPunctuation(CodepointSpan selection,
- const std::string& context) const;
+ // Initializes the model from mmap_ file.
+ void InitFromMmap();
+
+ // Extracts chunks from the context. The extraction proceeds from the center
+ // token determined by click_span and looks at relative_click_span tokens
+ // left and right around the click position.
+ // If relative_click_span == {kInvalidIndex, kInvalidIndex} then the whole
+ // context is considered, regardless of the click_span (which should point to
+  // the beginning {0, 1}).
+ // Returns the chunks sorted by their position in the context string.
+ // TODO(zilka): Tidy up the interface.
+ std::vector<CodepointSpan> Chunk(const std::string& context,
+ CodepointSpan click_span,
+ TokenSpan relative_click_span) const;
// During evaluation we need access to the feature processor.
FeatureProcessor* SelectionFeatureProcessor() const {
@@ -90,7 +127,14 @@
SharingModelOptions sharing_options_;
private:
- bool LoadModels(const nlp_core::MmapHandle& mmap_handle);
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+ struct CompiledRegexPattern {
+ std::string collection_name;
+ std::unique_ptr<icu::RegexPattern> pattern;
+ };
+#endif
+
+ bool LoadModels(const void* addr, int size);
nlp_core::EmbeddingNetwork::Vector InferInternal(
const std::string& context, CodepointSpan span,
@@ -108,8 +152,8 @@
CodepointSpan SuggestSelectionSymmetrical(const std::string& context,
CodepointSpan click_indices) const;
- bool initialized_;
- nlp_core::ScopedMmap mmap_;
+ bool initialized_ = false;
+ std::unique_ptr<nlp_core::ScopedMmap> mmap_;
std::unique_ptr<ModelParams> selection_params_;
std::unique_ptr<FeatureProcessor> selection_feature_processor_;
std::unique_ptr<nlp_core::EmbeddingNetwork> selection_network_;
@@ -118,14 +162,28 @@
std::unique_ptr<ModelParams> sharing_params_;
std::unique_ptr<nlp_core::EmbeddingNetwork> sharing_network_;
FeatureVectorFn sharing_feature_fn_;
-
- std::set<int> punctuation_to_strip_;
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+ std::vector<CompiledRegexPattern> regex_patterns_;
+#endif
};
// Parses the merged image given as a file descriptor, and reads
// the ModelOptions proto from the selection model.
bool ReadSelectionModelOptions(int fd, ModelOptions* model_options);
+// Prints "Span(begin, end, class, score)" from the first classification
+// entry, or with an empty class and a score of -1 if there is none.
+inline std::ostream& operator<<(
+    std::ostream& os, const TextClassificationModel::AnnotatedSpan& span) {
+  const bool has_class = !span.classification.empty();
+  const std::string top_class =
+      has_class ? span.classification.front().first : std::string();
+  const float top_score = has_class ? span.classification.front().second : -1;
+  os << "Span(" << span.span.first << ", " << span.span.second << ", ";
+  os << top_class << ", " << top_score << ")";
+  return os;
+}
+
} // namespace libtextclassifier
#endif // LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto
index b5b0287..ca10a0e 100644
--- a/smartselect/text-classification-model.proto
+++ b/smartselect/text-classification-model.proto
@@ -37,10 +37,7 @@
message SelectionModelOptions {
// A list of Unicode codepoints to strip from predicted selections.
- repeated int32 punctuation_to_strip = 1;
-
- // Whether to strip punctuation after the selection is made.
- optional bool strip_punctuation = 2;
+ repeated int32 deprecated_punctuation_to_strip = 1;
// Enforce symmetrical selections.
optional bool enforce_symmetry = 3;
@@ -48,6 +45,8 @@
// Number of inferences made around the click position (to one side), for
// enforcing symmetry.
optional int32 symmetry_context_size = 4;
+
+ reserved 2;
}
message SharingModelOptions {
@@ -60,8 +59,19 @@
// Limits for phone numbers.
optional int32 phone_min_num_digits = 3 [default = 7];
optional int32 phone_max_num_digits = 4 [default = 15];
+
+ // List of regular expression matchers to check.
+ message RegexPattern {
+ // The name of the collection of a match.
+ optional string collection_name = 1;
+
+ // The pattern to check.
+ optional string pattern = 2;
+ }
+ repeated RegexPattern regex_pattern = 5;
}
+// Next ID: 39
message FeatureProcessorOptions {
// Number of buckets used for hashing charactergrams.
optional int32 num_buckets = 1 [default = -1];
@@ -193,7 +203,18 @@
[default = INTERNAL_TOKENIZER];
optional bool icu_preserve_whitespace_tokens = 31 [default = false];
- reserved 7, 11, 12, 17, 26, 27, 28, 29, 32;
+ // List of codepoints that will be stripped from beginning and end of
+ // predicted spans.
+ repeated int32 ignored_span_boundary_codepoints = 36;
+
+ reserved 7, 11, 12, 26, 27, 28, 29, 32, 35;
+
+ // List of allowed charactergrams. The extracted charactergrams are filtered
+ // using this list, and charactergrams that are not present are interpreted as
+ // out-of-vocabulary.
+ // If no allowed_chargrams are specified, all charactergrams are allowed.
+ // The field is typed as bytes type to allow non-UTF8 chargrams.
+ repeated bytes allowed_chargrams = 38;
};
extend nlp_core.EmbeddingNetworkProto {
diff --git a/tests/text-classification-model_test.cc b/smartselect/text-classification-model_test.cc
similarity index 89%
rename from tests/text-classification-model_test.cc
rename to smartselect/text-classification-model_test.cc
index ed00876..490b395 100644
--- a/tests/text-classification-model_test.cc
+++ b/smartselect/text-classification-model_test.cc
@@ -21,7 +21,6 @@
#include <memory>
#include <string>
-#include "base.h"
#include "gtest/gtest.h"
namespace libtextclassifier {
@@ -140,37 +139,12 @@
explicit TestingTextClassificationModel(int fd)
: libtextclassifier::TextClassificationModel(fd) {}
- using libtextclassifier::TextClassificationModel::StripPunctuation;
-
void DisableClassificationHints() {
sharing_options_.set_always_accept_url_hint(false);
sharing_options_.set_always_accept_email_hint(false);
}
};
-TEST(TextClassificationModelTest, StripPunctuation) {
- const std::string model_path = GetModelPath();
- int fd = open(model_path.c_str(), O_RDONLY);
- std::unique_ptr<TestingTextClassificationModel> model(
- new TestingTextClassificationModel(fd));
- close(fd);
-
- EXPECT_EQ(std::make_pair(3, 10),
- model->StripPunctuation({0, 10}, ".,-abcd.()"));
- EXPECT_EQ(std::make_pair(0, 6), model->StripPunctuation({0, 6}, "(abcd)"));
- EXPECT_EQ(std::make_pair(1, 5), model->StripPunctuation({0, 6}, "[abcd]"));
- EXPECT_EQ(std::make_pair(1, 5), model->StripPunctuation({0, 6}, "{abcd}"));
-
- // Empty result.
- EXPECT_EQ(std::make_pair(0, 0), model->StripPunctuation({0, 1}, "&"));
- EXPECT_EQ(std::make_pair(0, 0), model->StripPunctuation({0, 4}, "&-,}"));
-
- // Invalid indices
- EXPECT_EQ(std::make_pair(-1, 523), model->StripPunctuation({-1, 523}, "a"));
- EXPECT_EQ(std::make_pair(-1, -1), model->StripPunctuation({-1, -1}, "a"));
- EXPECT_EQ(std::make_pair(0, -1), model->StripPunctuation({0, -1}, "a"));
-}
-
TEST(TextClassificationModelTest, SuggestSelectionNoCrashWithJunk) {
const std::string model_path = GetModelPath();
int fd = open(model_path.c_str(), O_RDONLY);
@@ -328,5 +302,46 @@
"phone: (123) 456 789,0001112", {7, 28}, 0)));
}
+TEST(TextClassificationModelTest, Annotate) {
+  const std::string model_path = GetModelPath();
+  int fd = open(model_path.c_str(), O_RDONLY);
+  std::unique_ptr<TestingTextClassificationModel> model(
+      new TestingTextClassificationModel(fd));
+  close(fd);
+  ASSERT_TRUE(model->IsInitialized());
+
+  std::string test_string =
+      "I saw Barak Obama today at 350 Third Street, Cambridge";
+  std::vector<TextClassificationModel::AnnotatedSpan> result =
+      model->Annotate(test_string);
+
+  // Expected chunks in context order. The 1.0 scores are placeholders: only
+  // the span and the first classification's collection name are compared.
+  std::vector<TextClassificationModel::AnnotatedSpan> expected;
+  auto add_expected = [&expected](CodepointSpan span,
+                                  const std::string& collection) {
+    expected.emplace_back();
+    expected.back().span = span;
+    expected.back().classification.push_back({collection, 1.0});
+  };
+  add_expected({0, 1}, "other");
+  add_expected({2, 5}, "other");
+  add_expected({6, 17}, "other");
+  add_expected({18, 23}, "other");
+  add_expected({24, 26}, "other");
+  add_expected({27, 54}, "address");
+
+  ASSERT_EQ(result.size(), expected.size());
+  // result[i] is streamed into every failure message to aid debugging.
+  for (size_t i = 0; i < expected.size(); ++i) {
+    EXPECT_EQ(result[i].span, expected[i].span) << result[i];
+    // Guards the [0] access below against an empty classification.
+    ASSERT_FALSE(result[i].classification.empty()) << result[i];
+    EXPECT_EQ(result[i].classification[0].first,
+              expected[i].classification[0].first)
+        << result[i];
+  }
+}
+
} // namespace
} // namespace libtextclassifier
diff --git a/smartselect/token-feature-extractor.cc b/smartselect/token-feature-extractor.cc
index 479be41..6afd951 100644
--- a/smartselect/token-feature-extractor.cc
+++ b/smartselect/token-feature-extractor.cc
@@ -16,14 +16,17 @@
#include "smartselect/token-feature-extractor.h"
+#include <cctype>
#include <string>
#include "util/base/logging.h"
#include "util/hash/farmhash.h"
#include "util/strings/stringpiece.h"
#include "util/utf8/unicodetext.h"
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
#include "unicode/regex.h"
#include "unicode/uchar.h"
+#endif
namespace libtextclassifier {
@@ -47,6 +50,7 @@
return copy;
}
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
void RemapTokenUnicode(const std::string& token,
const TokenFeatureExtractorOptions& options,
UnicodeText* remapped) {
@@ -70,12 +74,14 @@
icu_string.toUTF8String(utf8_str);
remapped->CopyUTF8(utf8_str.data(), utf8_str.length());
}
+#endif
} // namespace
TokenFeatureExtractor::TokenFeatureExtractor(
const TokenFeatureExtractorOptions& options)
: options_(options) {
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
UErrorCode status;
for (const std::string& pattern : options.regexp_features) {
status = U_ZERO_ERROR;
@@ -87,10 +93,44 @@
TC_LOG(WARNING) << "Failed to load pattern" << pattern;
}
}
+#else
+ bool found_unsupported_regexp_features = false;
+ for (const std::string& pattern : options.regexp_features) {
+ // A temporary solution to support this specific regexp pattern without
+ // adding too much binary size.
+ if (pattern == "^[^a-z]*$") {
+ enable_all_caps_feature_ = true;
+ } else {
+ found_unsupported_regexp_features = true;
+ }
+ }
+ if (found_unsupported_regexp_features) {
+ TC_LOG(WARNING) << "ICU not supported regexp features ignored.";
+ }
+#endif
}
int TokenFeatureExtractor::HashToken(StringPiece token) const {
-  return tcfarmhash::Fingerprint64(token) % options_.num_buckets;
+  if (options_.allowed_chargrams.empty()) {
+    return tcfarmhash::Fingerprint64(token) % options_.num_buckets;
+  }
+  // Padding and out-of-vocabulary tokens have extra buckets reserved because
+  // they are special and important tokens, and we don't want them to share
+  // embedding with other charactergrams.
+  // TODO(zilka): Experimentally verify.
+  const int kNumExtraBuckets = 2;
+  const std::string token_string = token.ToString();
+  if (token_string == "<PAD>") {
+    return 1;
+  }
+  if (options_.num_buckets <= kNumExtraBuckets ||
+      options_.allowed_chargrams.count(token_string) == 0) {
+    // Out-of-vocabulary. Also used when no regular buckets are configured,
+    // because the modulo below needs a positive divisor.
+    return 0;
+  }
+  return (tcfarmhash::Fingerprint64(token) %
+          (options_.num_buckets - kNumExtraBuckets)) + kNumExtraBuckets;
}
std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
@@ -126,19 +166,23 @@
// Upper-bound the number of charactergram extracted to avoid resizing.
result.reserve(options_.chargram_orders.size() * feature_word.size());
- // Generate the character-grams.
- for (int chargram_order : options_.chargram_orders) {
- if (chargram_order == 1) {
- for (int i = 1; i < feature_word.size() - 1; ++i) {
- result.push_back(
- HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
- }
- } else {
- for (int i = 0;
- i < static_cast<int>(feature_word.size()) - chargram_order + 1;
- ++i) {
- result.push_back(HashToken(
- StringPiece(feature_word, /*offset=*/i, /*len=*/chargram_order)));
+ if (options_.chargram_orders.empty()) {
+ result.push_back(HashToken(feature_word));
+ } else {
+ // Generate the character-grams.
+ for (int chargram_order : options_.chargram_orders) {
+ if (chargram_order == 1) {
+ for (int i = 1; i < feature_word.size() - 1; ++i) {
+ result.push_back(
+ HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
+ }
+ } else {
+ for (int i = 0;
+ i < static_cast<int>(feature_word.size()) - chargram_order + 1;
+ ++i) {
+ result.push_back(HashToken(StringPiece(feature_word, /*offset=*/i,
+ /*len=*/chargram_order)));
+ }
}
}
}
@@ -148,6 +192,7 @@
std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode(
const Token& token) const {
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
std::vector<int> result;
if (token.is_padding || token.value.empty()) {
result.push_back(HashToken("<PAD>"));
@@ -186,39 +231,47 @@
// Upper-bound the number of charactergram extracted to avoid resizing.
result.reserve(options_.chargram_orders.size() * feature_word.size());
- // Generate the character-grams.
- for (int chargram_order : options_.chargram_orders) {
- UnicodeText::const_iterator it_start = feature_word_unicode.begin();
- UnicodeText::const_iterator it_end = feature_word_unicode.end();
- if (chargram_order == 1) {
- ++it_start;
- --it_end;
- }
-
- UnicodeText::const_iterator it_chargram_start = it_start;
- UnicodeText::const_iterator it_chargram_end = it_start;
- bool chargram_is_complete = true;
- for (int i = 0; i < chargram_order; ++i) {
- if (it_chargram_end == it_end) {
- chargram_is_complete = false;
- break;
+ if (options_.chargram_orders.empty()) {
+ result.push_back(HashToken(feature_word));
+ } else {
+ // Generate the character-grams.
+ for (int chargram_order : options_.chargram_orders) {
+ UnicodeText::const_iterator it_start = feature_word_unicode.begin();
+ UnicodeText::const_iterator it_end = feature_word_unicode.end();
+ if (chargram_order == 1) {
+ ++it_start;
+ --it_end;
}
- ++it_chargram_end;
- }
- if (!chargram_is_complete) {
- continue;
- }
- for (; it_chargram_end <= it_end;
- ++it_chargram_start, ++it_chargram_end) {
- const int length_bytes =
- it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
- result.push_back(HashToken(
- StringPiece(it_chargram_start.utf8_data(), length_bytes)));
+ UnicodeText::const_iterator it_chargram_start = it_start;
+ UnicodeText::const_iterator it_chargram_end = it_start;
+ bool chargram_is_complete = true;
+ for (int i = 0; i < chargram_order; ++i) {
+ if (it_chargram_end == it_end) {
+ chargram_is_complete = false;
+ break;
+ }
+ ++it_chargram_end;
+ }
+ if (!chargram_is_complete) {
+ continue;
+ }
+
+ for (; it_chargram_end <= it_end;
+ ++it_chargram_start, ++it_chargram_end) {
+ const int length_bytes =
+ it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
+ result.push_back(HashToken(
+ StringPiece(it_chargram_start.utf8_data(), length_bytes)));
+ }
}
}
}
return result;
+#else
+ TC_LOG(WARNING) << "ICU not supported. No feature extracted.";
+ return {};
+#endif
}
bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
@@ -234,7 +287,20 @@
if (options_.unicode_aware_features) {
UnicodeText token_unicode =
UTF8ToUnicodeText(token.value, /*do_copy=*/false);
-    if (!token.value.empty() && u_isupper(*token_unicode.begin())) {
+    // begin() of an empty token must not be dereferenced.
+    bool is_upper = false;
+    if (!token.value.empty()) {
+      const int first_codepoint = *token_unicode.begin();
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+      is_upper = u_isupper(first_codepoint);
+#else
+      TC_LOG(WARNING) << "Using non-unicode isupper because ICU is disabled.";
+      // isupper() is undefined for values outside unsigned char range / EOF.
+      is_upper = first_codepoint >= 0 && first_codepoint < 0x80 &&
+                 isupper(first_codepoint);
+#endif
+    }
+    if (is_upper) {
dense_features->push_back(1.0);
} else {
dense_features->push_back(-1.0);
@@ -260,6 +326,7 @@
}
}
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
// Add regexp features.
if (!regex_patterns_.empty()) {
icu::UnicodeString unicode_str(token.value.c_str(), token.value.size(),
@@ -281,6 +348,23 @@
}
}
}
+#else
+  if (enable_all_caps_feature_) {
+    bool is_all_caps = true;
+    for (const char character_byte : token.value) {
+      if (islower(static_cast<unsigned char>(character_byte))) {
+        is_all_caps = false;
+        break;
+      }
+    }
+    if (is_all_caps) {
+      dense_features->push_back(1.0);
+    } else {
+      dense_features->push_back(-1.0);
+    }
+  }
+#endif
+
return true;
}
diff --git a/smartselect/token-feature-extractor.h b/smartselect/token-feature-extractor.h
index 8287fbd..5afeca4 100644
--- a/smartselect/token-feature-extractor.h
+++ b/smartselect/token-feature-extractor.h
@@ -18,12 +18,14 @@
#define LIBTEXTCLASSIFIER_SMARTSELECT_TOKEN_FEATURE_EXTRACTOR_H_
#include <memory>
+#include <unordered_set>
#include <vector>
-#include "base.h"
#include "smartselect/types.h"
#include "util/strings/stringpiece.h"
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
#include "unicode/regex.h"
+#endif
namespace libtextclassifier {
@@ -55,6 +57,12 @@
// Maximum length of a word.
int max_word_length = 20;
+
+ // List of allowed charactergrams. The extracted charactergrams are filtered
+ // using this list, and charactergrams that are not present are interpreted as
+ // out-of-vocabulary.
+ // If no allowed_chargrams are specified, all charactergrams are allowed.
+ std::unordered_set<std::string> allowed_chargrams;
};
class TokenFeatureExtractor {
@@ -73,8 +81,16 @@
std::vector<float>* dense_features) const;
int DenseFeaturesCount() const {
-    return options_.extract_case_feature +
-           options_.extract_selection_mask_feature + regex_patterns_.size();
+    // One feature per enabled boolean option (case, selection mask) plus the
+    // regexp-derived features.
+    int feature_count = options_.extract_case_feature;
+    feature_count += options_.extract_selection_mask_feature;
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+    feature_count += regex_patterns_.size();
+#else
+    feature_count += enable_all_caps_feature_ ? 1 : 0;
+#endif
+    return feature_count;
}
protected:
@@ -94,8 +110,11 @@
private:
TokenFeatureExtractorOptions options_;
-
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
std::vector<std::unique_ptr<icu::RegexPattern>> regex_patterns_;
+#else
+ bool enable_all_caps_feature_ = false;
+#endif
};
} // namespace libtextclassifier
diff --git a/tests/token-feature-extractor_test.cc b/smartselect/token-feature-extractor_test.cc
similarity index 76%
rename from tests/token-feature-extractor_test.cc
rename to smartselect/token-feature-extractor_test.cc
index c85ba50..4b635fd 100644
--- a/tests/token-feature-extractor_test.cc
+++ b/smartselect/token-feature-extractor_test.cc
@@ -98,6 +98,35 @@
EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
}
+TEST(TokenFeatureExtractorTest, ExtractAsciiNoChargrams) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = false;
+ options.extract_selection_mask_feature = true;
+ TestingTokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({extractor.HashToken("^Hello$")}));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({extractor.HashToken("^world!$")}));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+}
+
TEST(TokenFeatureExtractorTest, ExtractUnicode) {
TokenFeatureExtractorOptions options;
options.num_buckets = 1000;
@@ -168,6 +197,36 @@
EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
}
+TEST(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = true;
+ options.extract_selection_mask_feature = true;
+ TestingTokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({extractor.HashToken("^Hělló$")}));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray({
+ extractor.HashToken("^world!$"),
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+}
+
TEST(TokenFeatureExtractorTest, ICUCaseFeature) {
TokenFeatureExtractorOptions options;
options.num_buckets = 1000;
@@ -400,5 +459,85 @@
EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
}
+TEST(TokenFeatureExtractorTest, ExtractFiltered) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2, 3};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = false;
+ options.extract_selection_mask_feature = true;
+ options.allowed_chargrams.insert("^H");
+ options.allowed_chargrams.insert("ll");
+ options.allowed_chargrams.insert("llo");
+ options.allowed_chargrams.insert("w");
+ options.allowed_chargrams.insert("!");
+ options.allowed_chargrams.insert("\xc4"); // UTF8 control character.
+
+ TestingTokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ extractor.Extract(Token{"Hěllo", 0, 5}, true, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({
+ // clang-format off
+ 0,
+ extractor.HashToken("\xc4"),
+ 0,
+ 0,
+ 0,
+ 0,
+ extractor.HashToken("^H"),
+ 0,
+ 0,
+ 0,
+ extractor.HashToken("ll"),
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ extractor.HashToken("llo"),
+ 0
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray({
+ // clang-format off
+ extractor.HashToken("w"),
+ 0,
+ 0,
+ 0,
+ 0,
+ extractor.HashToken("!"),
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+ EXPECT_EQ(extractor.HashToken("<PAD>"), 1);
+}
+
} // namespace
} // namespace libtextclassifier
diff --git a/smartselect/tokenizer.cc b/smartselect/tokenizer.cc
index 2093fde..2489a61 100644
--- a/smartselect/tokenizer.cc
+++ b/smartselect/tokenizer.cc
@@ -16,49 +16,42 @@
#include "smartselect/tokenizer.h"
+#include <algorithm>
+
#include "util/strings/utf8.h"
#include "util/utf8/unicodetext.h"
namespace libtextclassifier {
-void Tokenizer::PrepareTokenizationCodepointRanges(
- const std::vector<TokenizationCodepointRange>& codepoint_range_configs) {
- codepoint_ranges_.clear();
- codepoint_ranges_.reserve(codepoint_range_configs.size());
- for (const TokenizationCodepointRange& range : codepoint_range_configs) {
- codepoint_ranges_.push_back(
- CodepointRange(range.start(), range.end(), range.role()));
- }
-
+Tokenizer::Tokenizer(
+ const std::vector<TokenizationCodepointRange>& codepoint_ranges)
+ : codepoint_ranges_(codepoint_ranges) {
std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
- [](const CodepointRange& a, const CodepointRange& b) {
- return a.start < b.start;
+ [](const TokenizationCodepointRange& a,
+ const TokenizationCodepointRange& b) {
+ return a.start() < b.start();
});
}
TokenizationCodepointRange::Role Tokenizer::FindTokenizationRole(
int codepoint) const {
- auto it = std::lower_bound(codepoint_ranges_.begin(), codepoint_ranges_.end(),
- codepoint,
- [](const CodepointRange& range, int codepoint) {
- // This function compares range with the
- // codepoint for the purpose of finding the first
- // greater or equal range. Because of the use of
- // std::lower_bound it needs to return true when
- // range < codepoint; the first time it will
- // return false the lower bound is found and
- // returned.
- //
- // It might seem weird that the condition is
- // range.end <= codepoint here but when codepoint
- // == range.end it means it's actually just
- // outside of the range, thus the range is less
- // than the codepoint.
- return range.end <= codepoint;
- });
- if (it != codepoint_ranges_.end() && it->start <= codepoint &&
- it->end > codepoint) {
- return it->role;
+ auto it = std::lower_bound(
+ codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint,
+ [](const TokenizationCodepointRange& range, int codepoint) {
+ // This function compares range with the codepoint for the purpose of
+ // finding the first greater or equal range. Because of the use of
+ // std::lower_bound it needs to return true when range < codepoint;
+ // the first time it will return false the lower bound is found and
+ // returned.
+ //
+ // It might seem weird that the condition is range.end <= codepoint
+ // here but when codepoint == range.end it means it's actually just
+ // outside of the range, thus the range is less than the codepoint.
+ return range.end() <= codepoint;
+ });
+ if (it != codepoint_ranges_.end() && it->start() <= codepoint &&
+ it->end() > codepoint) {
+ return it->role();
} else {
return TokenizationCodepointRange::DEFAULT_ROLE;
}
diff --git a/smartselect/tokenizer.h b/smartselect/tokenizer.h
index 897f7c4..4eb78f9 100644
--- a/smartselect/tokenizer.h
+++ b/smartselect/tokenizer.h
@@ -22,7 +22,6 @@
#include "smartselect/tokenizer.pb.h"
#include "smartselect/types.h"
-#include "util/base/integral_types.h"
namespace libtextclassifier {
@@ -31,29 +30,12 @@
class Tokenizer {
public:
explicit Tokenizer(
- const std::vector<TokenizationCodepointRange>& codepoint_range_configs) {
- PrepareTokenizationCodepointRanges(codepoint_range_configs);
- }
+ const std::vector<TokenizationCodepointRange>& codepoint_ranges);
// Tokenizes the input string using the selected tokenization method.
std::vector<Token> Tokenize(const std::string& utf8_text) const;
protected:
- // Represents a codepoint range [start, end) with its role for tokenization.
- struct CodepointRange {
- int32 start;
- int32 end;
- TokenizationCodepointRange::Role role;
-
- CodepointRange(int32 arg_start, int32 arg_end,
- TokenizationCodepointRange::Role arg_role)
- : start(arg_start), end(arg_end), role(arg_role) {}
- };
-
- // Prepares tokenization codepoint ranges for use in tokenization.
- void PrepareTokenizationCodepointRanges(
- const std::vector<TokenizationCodepointRange>& codepoint_range_configs);
-
// Finds the tokenization role for given codepoint.
// If the character is not found returns DEFAULT_ROLE.
// Internally uses binary search so should be O(log(# of codepoint_ranges)).
@@ -62,7 +44,7 @@
private:
// Codepoint ranges that determine how different codepoints are tokenized.
// The ranges must not overlap.
- std::vector<CodepointRange> codepoint_ranges_;
+ std::vector<TokenizationCodepointRange> codepoint_ranges_;
};
} // namespace libtextclassifier
diff --git a/tests/tokenizer_test.cc b/smartselect/tokenizer_test.cc
similarity index 100%
rename from tests/tokenizer_test.cc
rename to smartselect/tokenizer_test.cc
diff --git a/textclassifier_jni.cc b/textclassifier_jni.cc
index 8d64d87..8740f4c 100644
--- a/textclassifier_jni.cc
+++ b/textclassifier_jni.cc
@@ -23,9 +23,10 @@
#include "lang_id/lang-id.h"
#include "smartselect/text-classification-model.h"
+#include "util/java/scoped_local_ref.h"
-using libtextclassifier::TextClassificationModel;
using libtextclassifier::ModelOptions;
+using libtextclassifier::TextClassificationModel;
using libtextclassifier::nlp_core::lang_id::LangId;
namespace {
@@ -38,6 +39,11 @@
}
jclass string_class = env->FindClass("java/lang/String");
+ if (!string_class) {
+ TC_LOG(ERROR) << "Can't find String class";
+ return false;
+ }
+
jmethodID get_bytes_id =
env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
@@ -69,6 +75,11 @@
JNIEnv* env, const std::string& result_class_name,
const std::vector<std::pair<std::string, float>>& classification_result) {
jclass result_class = env->FindClass(result_class_name.c_str());
+ if (!result_class) {
+ TC_LOG(ERROR) << "Couldn't find result class: " << result_class_name;
+ return nullptr;
+ }
+
jmethodID result_class_constructor =
env->GetMethodID(result_class, "<init>", "(Ljava/lang/String;F)V");
@@ -155,22 +166,64 @@
} // namespace libtextclassifier
-using libtextclassifier::ConvertIndicesUTF8ToBMP;
-using libtextclassifier::ConvertIndicesBMPToUTF8;
using libtextclassifier::CodepointSpan;
+using libtextclassifier::ConvertIndicesBMPToUTF8;
+using libtextclassifier::ConvertIndicesUTF8ToBMP;
+using libtextclassifier::ScopedLocalRef;
-JNIEXPORT jlong JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeNew(JNIEnv* env,
- jobject thiz,
- jint fd) {
+JNI_METHOD(jlong, SmartSelection, nativeNew)
+(JNIEnv* env, jobject thiz, jint fd) {
TextClassificationModel* model = new TextClassificationModel(fd);
return reinterpret_cast<jlong>(model);
}
-JNIEXPORT jintArray JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeSuggest(
- JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end) {
+JNI_METHOD(jlong, SmartSelection, nativeNewFromPath)
+(JNIEnv* env, jobject thiz, jstring path) {
+ const std::string path_str = ToStlString(env, path);
+ TextClassificationModel* model = new TextClassificationModel(path_str);
+ return reinterpret_cast<jlong>(model);
+}
+
+JNI_METHOD(jlong, SmartSelection, nativeNewFromAssetFileDescriptor)
+(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
+ // Get system-level file descriptor from AssetFileDescriptor.
+ ScopedLocalRef<jclass> afd_class(
+ env->FindClass("android/content/res/AssetFileDescriptor"), env);
+ if (afd_class == nullptr) {
+ TC_LOG(ERROR) << "Couln't find AssetFileDescriptor.";
+ return reinterpret_cast<jlong>(nullptr);
+ }
+ jmethodID afd_class_getFileDescriptor = env->GetMethodID(
+ afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;");
+ if (afd_class_getFileDescriptor == nullptr) {
+ TC_LOG(ERROR) << "Couln't find getFileDescriptor.";
+ return reinterpret_cast<jlong>(nullptr);
+ }
+
+ ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"),
+ env);
+ if (fd_class == nullptr) {
+ TC_LOG(ERROR) << "Couln't find FileDescriptor.";
+ return reinterpret_cast<jlong>(nullptr);
+ }
+ jfieldID fd_class_descriptor =
+ env->GetFieldID(fd_class.get(), "descriptor", "I");
+ if (fd_class_descriptor == nullptr) {
+ TC_LOG(ERROR) << "Couln't find descriptor.";
+ return reinterpret_cast<jlong>(nullptr);
+ }
+
+ jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor);
+ jint bundle_cfd = env->GetIntField(bundle_jfd, fd_class_descriptor);
+
+ TextClassificationModel* model =
+ new TextClassificationModel(bundle_cfd, offset, size);
+ return reinterpret_cast<jlong>(model);
+}
+
+JNI_METHOD(jintArray, SmartSelection, nativeSuggest)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
+ jint selection_end) {
TextClassificationModel* model =
reinterpret_cast<TextClassificationModel*>(ptr);
@@ -187,10 +240,9 @@
return result;
}
-JNIEXPORT jobjectArray JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeClassifyText(
- JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end, jint input_flags) {
+JNI_METHOD(jobjectArray, SmartSelection, nativeClassifyText)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
+ jint selection_end, jint input_flags) {
TextClassificationModel* ff_model =
reinterpret_cast<TextClassificationModel*>(ptr);
const std::vector<std::pair<std::string, float>> classification_result =
@@ -198,28 +250,58 @@
{selection_begin, selection_end}, input_flags);
return ScoredStringsToJObjectArray(
- env, "android/view/textclassifier/SmartSelection$ClassificationResult",
+ env, TC_PACKAGE_PATH "SmartSelection$ClassificationResult",
classification_result);
}
-JNIEXPORT void JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeClose(JNIEnv* env,
- jobject thiz,
- jlong ptr) {
+JNI_METHOD(jobjectArray, SmartSelection, nativeAnnotate)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context) {
+ TextClassificationModel* model =
+ reinterpret_cast<TextClassificationModel*>(ptr);
+ std::string context_utf8 = ToStlString(env, context);
+ std::vector<TextClassificationModel::AnnotatedSpan> annotations =
+ model->Annotate(context_utf8);
+
+ jclass result_class =
+ env->FindClass(TC_PACKAGE_PATH "SmartSelection$AnnotatedSpan");
+ if (!result_class) {
+ TC_LOG(ERROR) << "Couldn't find result class: "
+ << TC_PACKAGE_PATH "SmartSelection$AnnotatedSpan";
+ return nullptr;
+ }
+
+ jmethodID result_class_constructor = env->GetMethodID(
+ result_class, "<init>",
+ "(II[L" TC_PACKAGE_PATH "SmartSelection$ClassificationResult;)V");
+
+ jobjectArray results =
+ env->NewObjectArray(annotations.size(), result_class, nullptr);
+
+ for (int i = 0; i < annotations.size(); ++i) {
+ CodepointSpan span_bmp =
+ ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span);
+ jobject result = env->NewObject(
+ result_class, result_class_constructor,
+ static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second),
+ ScoredStringsToJObjectArray(
+ env, TC_PACKAGE_PATH "SmartSelection$ClassificationResult",
+ annotations[i].classification));
+ env->SetObjectArrayElement(results, i, result);
+ env->DeleteLocalRef(result);
+ }
+ env->DeleteLocalRef(result_class);
+ return results;
+}
+
+JNI_METHOD(void, SmartSelection, nativeClose)
+(JNIEnv* env, jobject thiz, jlong ptr) {
TextClassificationModel* model =
reinterpret_cast<TextClassificationModel*>(ptr);
delete model;
}
-JNIEXPORT jlong JNICALL Java_android_view_textclassifier_LangId_nativeNew(
- JNIEnv* env, jobject thiz, jint fd) {
- return reinterpret_cast<jlong>(new LangId(fd));
-}
-
-JNIEXPORT jstring JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeGetLanguage(JNIEnv* env,
- jobject clazz,
- jint fd) {
+JNI_METHOD(jstring, SmartSelection, nativeGetLanguage)
+(JNIEnv* env, jobject clazz, jint fd) {
ModelOptions model_options;
if (ReadSelectionModelOptions(fd, &model_options)) {
return env->NewStringUTF(model_options.language().c_str());
@@ -228,10 +310,8 @@
}
}
-JNIEXPORT jint JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeGetVersion(JNIEnv* env,
- jobject clazz,
- jint fd) {
+JNI_METHOD(jint, SmartSelection, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jint fd) {
ModelOptions model_options;
if (ReadSelectionModelOptions(fd, &model_options)) {
return model_options.version();
@@ -240,28 +320,31 @@
}
}
-JNIEXPORT jobjectArray JNICALL
-Java_android_view_textclassifier_LangId_nativeFindLanguages(JNIEnv* env,
- jobject thiz,
- jlong ptr,
- jstring text) {
+#ifndef LIBTEXTCLASSIFIER_DISABLE_LANG_ID
+JNI_METHOD(jlong, LangId, nativeNew)
+(JNIEnv* env, jobject thiz, jint fd) {
+ return reinterpret_cast<jlong>(new LangId(fd));
+}
+
+JNI_METHOD(jobjectArray, LangId, nativeFindLanguages)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring text) {
LangId* lang_id = reinterpret_cast<LangId*>(ptr);
const std::vector<std::pair<std::string, float>> scored_languages =
lang_id->FindLanguages(ToStlString(env, text));
return ScoredStringsToJObjectArray(
- env, "android/view/textclassifier/LangId$ClassificationResult",
- scored_languages);
+ env, TC_PACKAGE_PATH "LangId$ClassificationResult", scored_languages);
}
-JNIEXPORT void JNICALL Java_android_view_textclassifier_LangId_nativeClose(
- JNIEnv* env, jobject thiz, jlong ptr) {
+JNI_METHOD(void, LangId, nativeClose)
+(JNIEnv* env, jobject thiz, jlong ptr) {
LangId* lang_id = reinterpret_cast<LangId*>(ptr);
delete lang_id;
}
-JNIEXPORT int JNICALL Java_android_view_textclassifier_LangId_nativeGetVersion(
- JNIEnv* env, jobject clazz, jint fd) {
+JNI_METHOD(int, LangId, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jint fd) {
std::unique_ptr<LangId> lang_id(new LangId(fd));
return lang_id->version();
}
+#endif
diff --git a/textclassifier_jni.h b/textclassifier_jni.h
index 28bb444..1709ff4 100644
--- a/textclassifier_jni.h
+++ b/textclassifier_jni.h
@@ -22,56 +22,70 @@
#include "smartselect/types.h"
+#ifndef TC_PACKAGE_NAME
+#define TC_PACKAGE_NAME android_view_textclassifier
+#endif
+#ifndef TC_PACKAGE_PATH
+#define TC_PACKAGE_PATH "android/view/textclassifier/"
+#endif
+
+#define JNI_METHOD_PRIMITIVE(return_type, package_name, class_name, \
+ method_name) \
+ JNIEXPORT return_type JNICALL \
+ Java_##package_name##_##class_name##_##method_name
+
+// The indirection is needed to correctly expand the TC_PACKAGE_NAME macro.
+#define JNI_METHOD2(return_type, package_name, class_name, method_name) \
+ JNI_METHOD_PRIMITIVE(return_type, package_name, class_name, method_name)
+
+#define JNI_METHOD(return_type, class_name, method_name) \
+ JNI_METHOD2(return_type, TC_PACKAGE_NAME, class_name, method_name)
+
#ifdef __cplusplus
extern "C" {
#endif
// SmartSelection.
-JNIEXPORT jlong JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeNew(JNIEnv* env,
- jobject thiz,
- jint fd);
+JNI_METHOD(jlong, SmartSelection, nativeNew)
+(JNIEnv* env, jobject thiz, jint fd);
-JNIEXPORT jintArray JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeSuggest(
- JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end);
+JNI_METHOD(jlong, SmartSelection, nativeNewFromPath)
+(JNIEnv* env, jobject thiz, jstring path);
-JNIEXPORT jobjectArray JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeClassifyText(
- JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end, jint input_flags);
+JNI_METHOD(jlong, SmartSelection, nativeNewFromAssetFileDescriptor)
+(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);
-JNIEXPORT void JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeClose(JNIEnv* env,
- jobject thiz,
- jlong ptr);
+JNI_METHOD(jintArray, SmartSelection, nativeSuggest)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
+ jint selection_end);
-JNIEXPORT jstring JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeGetLanguage(JNIEnv* env,
- jobject clazz,
- jint fd);
+JNI_METHOD(jobjectArray, SmartSelection, nativeClassifyText)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
+ jint selection_end, jint input_flags);
-JNIEXPORT jint JNICALL
-Java_android_view_textclassifier_SmartSelection_nativeGetVersion(JNIEnv* env,
- jobject clazz,
- jint fd);
+JNI_METHOD(jobjectArray, SmartSelection, nativeAnnotate)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context);
+JNI_METHOD(void, SmartSelection, nativeClose)
+(JNIEnv* env, jobject thiz, jlong ptr);
+
+JNI_METHOD(jstring, SmartSelection, nativeGetLanguage)
+(JNIEnv* env, jobject clazz, jint fd);
+
+JNI_METHOD(jint, SmartSelection, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jint fd);
+
+#ifndef LIBTEXTCLASSIFIER_DISABLE_LANG_ID
// LangId.
-JNIEXPORT jlong JNICALL Java_android_view_textclassifier_LangId_nativeNew(
- JNIEnv* env, jobject thiz, jint fd);
+JNI_METHOD(jlong, LangId, nativeNew)(JNIEnv* env, jobject thiz, jint fd);
-JNIEXPORT jobjectArray JNICALL
-Java_android_view_textclassifier_LangId_nativeFindLanguages(JNIEnv* env,
- jobject thiz,
- jlong ptr,
- jstring text);
+JNI_METHOD(jobjectArray, LangId, nativeFindLanguages)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring text);
-JNIEXPORT void JNICALL Java_android_view_textclassifier_LangId_nativeClose(
- JNIEnv* env, jobject thiz, jlong ptr);
+JNI_METHOD(void, LangId, nativeClose)(JNIEnv* env, jobject thiz, jlong ptr);
-JNIEXPORT int JNICALL Java_android_view_textclassifier_LangId_nativeGetVersion(
- JNIEnv* env, jobject clazz, jint fd);
+JNI_METHOD(int, LangId, nativeGetVersion)(JNIEnv* env, jobject clazz, jint fd);
+#endif
#ifdef __cplusplus
}
diff --git a/tests/textclassifier_jni_test.cc b/textclassifier_jni_test.cc
similarity index 100%
rename from tests/textclassifier_jni_test.cc
rename to textclassifier_jni_test.cc
diff --git a/util/base/casts.h b/util/base/casts.h
index ad12ce4..805ee89 100644
--- a/util/base/casts.h
+++ b/util/base/casts.h
@@ -21,13 +21,12 @@
namespace libtextclassifier {
-// lang_id_bit_cast<Dest,Source> is a template function that implements the
-// equivalent of "*reinterpret_cast<Dest*>(&source)". We need this in
-// very low-level functions like the protobuf library and fast math
-// support.
+// bit_cast<Dest, Source> is a template function that implements the equivalent
+// of "*reinterpret_cast<Dest*>(&source)". We need this in very low-level
+// functions like fast math support.
//
// float f = 3.14159265358979;
-// int i = lang_id_bit_cast<int32>(f);
+// int i = bit_cast<int32>(f);
// // i = 0x40490fdb
//
// The classical address-casting method is:
@@ -60,9 +59,9 @@
//
// Anyways ...
//
-// lang_id_bit_cast<> calls memcpy() which is blessed by the standard,
-// especially by the example in section 3.9 . Also, of course,
-// lang_id_bit_cast<> wraps up the nasty logic in one place.
+// bit_cast<> calls memcpy() which is blessed by the standard, especially by the
+// example in section 3.9 . Also, of course, bit_cast<> wraps up the nasty
+// logic in one place.
//
// Fortunately memcpy() is very fast. In optimized mode, with a
// constant size, gcc 2.95.3, gcc 4.0.1, and msvc 7.1 produce inline
@@ -70,15 +69,14 @@
// memcpy(d,s,4) compiles to one load and one store, and memcpy(d,s,8)
// compiles to two loads and two stores.
//
-// I tested this code with gcc 2.95.3, gcc 4.0.1, icc 8.1, and msvc 7.1.
+// Mike Chastain tested this code with gcc 2.95.3, gcc 4.0.1, icc 8.1, and msvc
+// 7.1.
//
// WARNING: if Dest or Source is a non-POD type, the result of the memcpy
// is likely to surprise you.
//
// Props to Bill Gibbons for the compile time assertion technique and
// Art Komninos and Igor Tandetnik for the msvc experiments.
-//
-// -- mec 2005-10-17
template <class Dest, class Source>
inline Dest bit_cast(const Source &source) {
diff --git a/base.h b/util/base/endian.h
similarity index 89%
rename from base.h
rename to util/base/endian.h
index 6829f1c..5813288 100644
--- a/base.h
+++ b/util/base/endian.h
@@ -14,25 +14,13 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_BASE_H_
-#define LIBTEXTCLASSIFIER_BASE_H_
+#ifndef LIBTEXTCLASSIFIER_UTIL_BASE_ENDIAN_H_
+#define LIBTEXTCLASSIFIER_UTIL_BASE_ENDIAN_H_
-#include <cassert>
-#include <map>
-#include <string>
-#include <vector>
-
-#include "util/base/config.h"
#include "util/base/integral_types.h"
namespace libtextclassifier {
-#ifdef INTERNAL_BUILD
-typedef basic_string<char> bstring;
-#else
-typedef std::basic_string<char> bstring;
-#endif // INTERNAL_BUILD
-
#if defined OS_LINUX || defined OS_CYGWIN || defined OS_ANDROID || \
defined(__ANDROID__)
#include <endian.h>
@@ -126,4 +114,4 @@
} // namespace libtextclassifier
-#endif // LIBTEXTCLASSIFIER_BASE_H_
+#endif // LIBTEXTCLASSIFIER_UTIL_BASE_ENDIAN_H_
diff --git a/util/base/logging.h b/util/base/logging.h
index b0f3c5d..dba0ed4 100644
--- a/util/base/logging.h
+++ b/util/base/logging.h
@@ -24,6 +24,23 @@
#include "util/base/logging_levels.h"
#include "util/base/port.h"
+// TC_STRIP
+namespace libtextclassifier {
+// string class that can't be instantiated. Makes sure that the code does not
+// compile when non std::string is used.
+//
+// NOTE: defined here because most files directly or transitively include this
+// file. Asking people to include a special header just to make sure they don't
+// use the unqualified string doesn't work: as that header doesn't produce any
+// immediate benefit, one can easily forget about it.
+class string {
+ public:
+ // Makes the class non-instantiable.
+ virtual ~string() = 0;
+};
+} // namespace libtextclassifier
+// TC_END_STRIP
+
namespace libtextclassifier {
namespace logging {
@@ -75,10 +92,6 @@
#define TC_CHECK_GE(x, y) TC_CHECK((x) >= (y))
#define TC_CHECK_NE(x, y) TC_CHECK((x) != (y))
-// Debug checks: a TC_DCHECK<suffix> macro should behave like TC_CHECK<suffix>
-// in debug mode an don't check / don't print anything in non-debug mode.
-#ifdef NDEBUG
-
// Pseudo-stream that "eats" the tokens <<-pumped into it, without printing
// anything.
class NullStream {
@@ -92,6 +105,11 @@
}
#define TC_NULLSTREAM ::libtextclassifier::logging::NullStream().stream()
+
+// Debug checks: a TC_DCHECK<suffix> macro should behave like TC_CHECK<suffix>
+// in debug mode an don't check / don't print anything in non-debug mode.
+#ifdef NDEBUG
+
#define TC_DCHECK(x) TC_NULLSTREAM
#define TC_DCHECK_EQ(x, y) TC_NULLSTREAM
#define TC_DCHECK_LT(x, y) TC_NULLSTREAM
@@ -113,6 +131,16 @@
#define TC_DCHECK_NE(x, y) TC_CHECK_NE(x, y)
#endif // NDEBUG
+
+#ifdef LIBTEXTCLASSIFIER_VLOG
+#define TC_VLOG(severity) \
+ ::libtextclassifier::logging::LogMessage(::libtextclassifier::logging::INFO, \
+ __FILE__, __LINE__) \
+ .stream()
+#else
+#define TC_VLOG(severity) TC_NULLSTREAM
+#endif
+
} // namespace logging
} // namespace libtextclassifier
diff --git a/util/hash/farmhash.cc b/util/hash/farmhash.cc
index 55786a9..f4f2e84 100644
--- a/util/hash/farmhash.cc
+++ b/util/hash/farmhash.cc
@@ -642,7 +642,7 @@
uint32_t Hash32(const char *s, size_t len) {
FARMHASH_DIE_IF_MISCONFIGURED;
- return s == NULL ? 0 : len;
+ return s == nullptr ? 0 : len;
}
uint32_t Hash32WithSeed(const char *s, size_t len, uint32_t seed) {
@@ -865,7 +865,7 @@
uint32_t Hash32(const char *s, size_t len) {
FARMHASH_DIE_IF_MISCONFIGURED;
- return s == NULL ? 0 : len;
+ return s == nullptr ? 0 : len;
}
uint32_t Hash32WithSeed(const char *s, size_t len, uint32_t seed) {
@@ -894,7 +894,7 @@
uint32_t Hash32(const char *s, size_t len) {
FARMHASH_DIE_IF_MISCONFIGURED;
- return s == NULL ? 0 : len;
+ return s == nullptr ? 0 : len;
}
uint32_t Hash32WithSeed(const char *s, size_t len, uint32_t seed) {
diff --git a/util/java/scoped_local_ref.h b/util/java/scoped_local_ref.h
new file mode 100644
index 0000000..d995468
--- /dev/null
+++ b/util/java/scoped_local_ref.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_LOCAL_REF_H_
+#define LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_LOCAL_REF_H_
+
+#include <jni.h>
+#include <memory>
+#include <type_traits>
+
+#include "util/base/logging.h"
+
+namespace libtextclassifier {
+
+// A deleter to be used with std::unique_ptr to delete JNI local references.
+class LocalRefDeleter {
+ public:
+ // Style guide violating implicit constructor so that the LocalRefDeleter
+ // is implicitly constructed from the second argument to ScopedLocalRef.
+ LocalRefDeleter(JNIEnv* env) : env_(env) {} // NOLINT(runtime/explicit)
+
+ LocalRefDeleter(const LocalRefDeleter& orig) = default;
+
+ // Copy assignment to allow move semantics in ScopedLocalRef.
+ LocalRefDeleter& operator=(const LocalRefDeleter& rhs) {
+ // As the deleter and its state are thread-local, ensure the envs
+ // are consistent but do nothing.
+ TC_CHECK_EQ(env_, rhs.env_);
+ return *this;
+ }
+
+ // The delete operator.
+ void operator()(jobject o) const { env_->DeleteLocalRef(o); }
+
+ private:
+ // The env_ stashed to use for deletion. Thread-local, don't share!
+ JNIEnv* const env_;
+};
+
+// A smart pointer that deletes a JNI local reference when it goes out
+// of scope. Usage is:
+// ScopedLocalRef<jobject> scoped_local(env->JniFunction(), env);
+//
+// Note that this class is not thread-safe since it caches JNIEnv in
+// the deleter. Do not use the same jobject across different threads.
+template <typename T>
+using ScopedLocalRef =
+ std::unique_ptr<typename std::remove_pointer<T>::type, LocalRefDeleter>;
+
+} // namespace libtextclassifier
+
+#endif // LIBTEXTCLASSIFIER_UTIL_JAVA_SCOPED_LOCAL_REF_H_
diff --git a/tests/numbers_test.cc b/util/strings/numbers_test.cc
similarity index 100%
rename from tests/numbers_test.cc
rename to util/strings/numbers_test.cc
diff --git a/util/utf8/unicodetext.cc b/util/utf8/unicodetext.cc
index e83c890..dbab1c8 100644
--- a/util/utf8/unicodetext.cc
+++ b/util/utf8/unicodetext.cc
@@ -16,7 +16,10 @@
#include "util/utf8/unicodetext.h"
-#include "base.h"
+#include <string.h>
+
+#include <algorithm>
+
#include "util/strings/utf8.h"
namespace libtextclassifier {
@@ -108,6 +111,8 @@
void UnicodeText::clear() { repr_.clear(); }
+int UnicodeText::size() const { return std::distance(begin(), end()); }
+
std::string UnicodeText::UTF8Substring(const const_iterator& first,
const const_iterator& last) {
return std::string(first.it_, last.it_ - first.it_);
diff --git a/util/utf8/unicodetext.h b/util/utf8/unicodetext.h
index 5327383..6a21058 100644
--- a/util/utf8/unicodetext.h
+++ b/util/utf8/unicodetext.h
@@ -17,9 +17,11 @@
#ifndef LIBTEXTCLASSIFIER_UTIL_UTF8_UNICODETEXT_H_
#define LIBTEXTCLASSIFIER_UTIL_UTF8_UNICODETEXT_H_
+#include <iterator>
+#include <string>
#include <utility>
-#include "base.h"
+#include "util/base/integral_types.h"
namespace libtextclassifier {
@@ -137,6 +139,7 @@
const_iterator begin() const;
const_iterator end() const;
+ int size() const; // the number of Unicode characters (codepoints)
// x.PointToUTF8(buf,len) changes x so that it points to buf
// ("becomes an alias"). It does not take ownership or copy buf.
@@ -162,7 +165,7 @@
int capacity_;
bool ours_; // Do we own data_?
- Repr() : data_(NULL), size_(0), capacity_(0), ours_(true) {}
+ Repr() : data_(nullptr), size_(0), capacity_(0), ours_(true) {}
~Repr() {
if (ours_) delete[] data_;
}