Sync with Google3. am: d2c364b75f
am: 0206f5b9cd
Change-Id: I22d7dff37de9ffac4424b821f4280bed9a265443
diff --git a/models/textclassifier.smartselection.en.model b/models/textclassifier.smartselection.en.model
index 9b35fdd..850033a 100644
--- a/models/textclassifier.smartselection.en.model
+++ b/models/textclassifier.smartselection.en.model
Binary files differ
diff --git a/smartselect/text-classification-model.cc b/smartselect/text-classification-model.cc
index 71457bc..18a15eb 100644
--- a/smartselect/text-classification-model.cc
+++ b/smartselect/text-classification-model.cc
@@ -132,6 +132,23 @@
};
}
+// Decodes the little-endian uint32 stored at |data|. The bytes are assembled
+// one at a time because |data| is generally not 4-byte aligned (the sharing
+// model length follows a selection model of arbitrary size), and loading it
+// through a uint32* would be undefined behavior.
+uint32 DecodeLittleEndian32(const char* data) {
+  const unsigned char* bytes = reinterpret_cast<const unsigned char*>(data);
+  return static_cast<uint32>(bytes[0]) |
+         (static_cast<uint32>(bytes[1]) << 8) |
+         (static_cast<uint32>(bytes[2]) << 16) |
+         (static_cast<uint32>(bytes[3]) << 24);
+}
+
+// Reads one length-prefixed section ([uint32 LE length][payload]) starting at
+// |*offset| inside the |size|-byte buffer |image|. On success stores the
+// payload start and length, advances |*offset| past the payload and returns
+// true. Returns false, without touching the outputs, if the prefix or the
+// payload would run past the end of the buffer.
+bool ReadSection(const char* image, size_t size, size_t* offset,
+                 const char** section, int* section_length) {
+  if (size - *offset < sizeof(uint32)) {
+    return false;
+  }
+  const uint32 length = DecodeLittleEndian32(image + *offset);
+  *offset += sizeof(uint32);
+  // The length must fit the int out-parameter and the bytes that remain.
+  if (length > 0x7fffffffu || length > size - *offset) {
+    return false;
+  }
+  *section = image + *offset;
+  *section_length = static_cast<int>(length);
+  *offset += length;
+  return true;
+}
+
+// Splits the merged model image in |mmap_handle| into its two sub-models.
+// The layout read here is:
+//   [uint32 LE selection length][selection model bytes]
+//   [uint32 LE sharing length][sharing model bytes]
+// On success the out-pointers point into the mapped memory (nothing is
+// copied) and true is returned. If the image is shorter than the lengths it
+// declares, an error is logged, all outputs are set to nullptr/0 and false is
+// returned.
+bool ParseMergedModel(const MmapHandle& mmap_handle,
+                      const char** selection_model, int* selection_model_length,
+                      const char** sharing_model, int* sharing_model_length) {
+  const char* const image = reinterpret_cast<const char*>(mmap_handle.start());
+  const size_t image_size = mmap_handle.num_bytes();
+  size_t offset = 0;
+  if (ReadSection(image, image_size, &offset, selection_model,
+                  selection_model_length) &&
+      ReadSection(image, image_size, &offset, sharing_model,
+                  sharing_model_length)) {
+    return true;
+  }
+
+  TC_LOG(ERROR) << "Merged model image is truncated or corrupt.";
+  *selection_model = nullptr;
+  *selection_model_length = 0;
+  *sharing_model = nullptr;
+  *sharing_model_length = 0;
+  return false;
+}
+
} // namespace
bool TextClassificationModel::LoadModels(int fd) {
@@ -140,14 +157,13 @@
return false;
}
- // Read the length of the selection model.
- const char* model_data = reinterpret_cast<const char*>(mmap_handle.start());
- uint32 selection_model_length =
- LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
- model_data += sizeof(selection_model_length);
+ const char *selection_model, *sharing_model;
+ int selection_model_length, sharing_model_length;
+ ParseMergedModel(mmap_handle, &selection_model, &selection_model_length,
+ &sharing_model, &sharing_model_length);
selection_params_.reset(
- ModelParamsBuilder(model_data, selection_model_length, nullptr));
+ ModelParamsBuilder(selection_model, selection_model_length, nullptr));
if (!selection_params_.get()) {
return false;
}
@@ -157,12 +173,8 @@
selection_feature_fn_ = CreateFeatureVectorFn(
*selection_network_, selection_network_->EmbeddingSize(0));
- model_data += selection_model_length;
- uint32 sharing_model_length =
- LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
- model_data += sizeof(sharing_model_length);
sharing_params_.reset(
- ModelParamsBuilder(model_data, sharing_model_length,
+ ModelParamsBuilder(sharing_model, sharing_model_length,
selection_params_->GetEmbeddingParams()));
if (!sharing_params_.get()) {
return false;
@@ -176,6 +188,31 @@
return true;
}
+// Mmaps the merged model image in |fd|. If the selection model carries a
+// ModelOptions extension, copies it into |model_options| and returns true.
+// Returns false, leaving |model_options| untouched, if the file cannot be
+// mmapped, the image cannot be split, or the extension is absent.
+//
+// The result is a copy, so nothing returned from here points into the
+// mapping. The mapping is therefore released before returning; otherwise
+// every call would leak it.
+bool ReadSelectionModelOptions(int fd, ModelOptions* model_options) {
+  MmapHandle mmap_handle = MmapFile(fd);
+  if (!mmap_handle.ok()) {
+    TC_LOG(ERROR) << "Can't mmap.";
+    return false;
+  }
+
+  const char* selection_model = nullptr;
+  const char* sharing_model = nullptr;
+  int selection_model_length = 0;
+  int sharing_model_length = 0;
+  ParseMergedModel(mmap_handle, &selection_model, &selection_model_length,
+                   &sharing_model, &sharing_model_length);
+
+  bool found = false;
+  if (selection_model != nullptr && selection_model_length > 0) {
+    // The reader lives only inside this scope, so it is destroyed before the
+    // mapping it was built from is released below.
+    MemoryImageReader<EmbeddingNetworkProto> reader(selection_model,
+                                                    selection_model_length);
+    auto model_options_extension_id = model_options_in_embedding_network_proto;
+    if (reader.trimmed_proto().HasExtension(model_options_extension_id)) {
+      *model_options =
+          reader.trimmed_proto().GetExtension(model_options_extension_id);
+      found = true;
+    }
+  }
+
+  Unmap(mmap_handle);
+  return found;
+}
+
EmbeddingNetwork::Vector TextClassificationModel::InferInternal(
const std::string& context, CodepointSpan span,
const FeatureProcessor& feature_processor, const EmbeddingNetwork& network,
diff --git a/smartselect/text-classification-model.h b/smartselect/text-classification-model.h
index ae2049b..1bc6640 100644
--- a/smartselect/text-classification-model.h
+++ b/smartselect/text-classification-model.h
@@ -114,6 +114,10 @@
std::set<int> punctuation_to_strip_;
};
+// Parses the merged model image given as a file descriptor, and reads the
+// ModelOptions extension of the selection model into |model_options|.
+// Returns true if the options were read. Returns false, leaving
+// |model_options| untouched, when they cannot be read (for example, |fd|
+// cannot be mmapped or the selection model has no ModelOptions extension).
+bool ReadSelectionModelOptions(int fd, ModelOptions* model_options);
+
} // namespace libtextclassifier
#endif // LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto
index 4e21af4..51d28c8 100644
--- a/smartselect/text-classification-model.proto
+++ b/smartselect/text-classification-model.proto
@@ -27,6 +27,12 @@
// If true, will use embeddings from a different model. This is mainly useful
// for the Sharing model using the embeddings from the Selection model.
optional bool use_shared_embeddings = 1;
+
+ // Language of the model, e.g. "en" for the bundled English model. Read by
+ // ReadSelectionModelOptions() and exposed to Java via nativeGetLanguage().
+ optional string language = 2;
+
+ // Version of the model. The bundled English model sets a positive value;
+ // exposed to Java via nativeGetVersion().
+ optional int32 version = 3;
}
message SelectionModelOptions {
@@ -104,13 +110,6 @@
// be mapped to an id.
optional int32 default_collection = 10 [default = -1];
- // Probability with which to drop context of examples.
- optional float context_dropout_probability = 11 [default = 0.0];
-
- // If true, drop variable amounts of context, if false all context, with
- // probability given by context_dropout_ratio.
- optional bool use_variable_context_dropout = 12 [default = false];
-
// If true, will split the input by lines, and only use the line that contains
// the clicked token.
optional bool only_use_line_with_click = 13 [default = false];
@@ -137,9 +136,6 @@
}
optional CenterTokenSelectionMethod center_token_selection_method = 16;
- // If true, during training will click random token in the example selection.
- optional bool click_random_token_in_selection = 17 [default = true];
-
// If true, span boundaries will be snapped to containing tokens and not
// required to exactly match token boundaries.
optional bool snap_label_span_boundaries_to_containing_tokens = 18;
@@ -167,27 +163,6 @@
// to it. So the resulting feature vector has two regions.
optional int32 feature_version = 25 [default = 0];
- // These settings control whether a distortion is applied to part of the data,
- // for Smart Sharing. Distortion means modifying (expanding) the bounds of the
- // selection and changing the example's collection to "other". The goal is to
- // expose the model to overselections as negative examples.
- // If true, distortion is applied. Otherwise the other settings are ignored.
- optional bool distortion_enable = 26;
- // Probability settings. They individual values and their sum should be in the
- // range [0, 1]. They specify the probability of the distorition being applied
- // to just one of the bounds (left or right with equal probabiolity), both
- // bounds, or not at all (the remaining probability).
- // If the context does not contain tokens on the given side of the selection,
- // the probabilistic decision is ignored. This means that the actual frequency
- // of distortion is somewhat lower than specified here.
- optional double distortion_probability_one_side = 27;
- optional double distortion_probability_both_sides = 28;
- // The maximum number of tokens to include (on one side) when distortion is
- // applied. The actual number is selected (independently for each side)
- // uniformly from integers from 1 to this value, inclusive. If the context is
- // too short, the end result is truncated.
- optional double distortion_max_num_tokens = 29;
-
// Controls the type of tokenization the model will use for the input text.
enum TokenizationType {
INVALID_TOKENIZATION_TYPE = 0;
@@ -202,7 +177,7 @@
[default = INTERNAL_TOKENIZER];
optional bool icu_preserve_whitespace_tokens = 31 [default = false];
- reserved 7;
+  reserved 7, 11, 12, 17, 26, 27, 28, 29, 32;
+  // Also reserve the names of the fields removed above. Text-format configs
+  // then cannot silently reuse them with a different meaning.
+  reserved "context_dropout_probability", "use_variable_context_dropout",
+      "click_random_token_in_selection", "distortion_enable",
+      "distortion_probability_one_side", "distortion_probability_both_sides",
+      "distortion_max_num_tokens";
};
extend nlp_core.EmbeddingNetworkProto {
diff --git a/tests/testdata/smartselection.model b/tests/testdata/smartselection.model
index 4f133f0..850033a 100644
--- a/tests/testdata/smartselection.model
+++ b/tests/testdata/smartselection.model
Binary files differ
diff --git a/tests/text-classification-model_test.cc b/tests/text-classification-model_test.cc
index c52ad38..10da631 100644
--- a/tests/text-classification-model_test.cc
+++ b/tests/text-classification-model_test.cc
@@ -31,6 +31,17 @@
return TEST_DATA_DIR "smartselection.model";
}
+TEST(TextClassificationModelTest, ReadModelOptions) {
+  const std::string model_path = GetModelPath();
+  int fd = open(model_path.c_str(), O_RDONLY);
+  ASSERT_GE(fd, 0) << "Can't open " << model_path;
+
+  ModelOptions model_options;
+  // Close the descriptor before asserting. ASSERT_* returns from the test
+  // immediately, so asserting first would leak the fd on failure.
+  const bool read_ok = ReadSelectionModelOptions(fd, &model_options);
+  close(fd);
+  ASSERT_TRUE(read_ok);
+
+  EXPECT_EQ("en", model_options.language());
+  EXPECT_GT(model_options.version(), 0);
+}
+
TEST(TextClassificationModelTest, SuggestSelection) {
const std::string model_path = GetModelPath();
int fd = open(model_path.c_str(), O_RDONLY);
diff --git a/textclassifier_jni.cc b/textclassifier_jni.cc
index 66e2934..4ce08a0 100644
--- a/textclassifier_jni.cc
+++ b/textclassifier_jni.cc
@@ -47,6 +47,16 @@
jobject thiz,
jlong ptr);
+// Returns the language stored in the SmartSelection model file open at |fd|,
+// or "UNK" if it cannot be determined.
+JNIEXPORT jstring JNICALL
+Java_android_view_textclassifier_SmartSelection_nativeGetLanguage(JNIEnv* env,
+ jobject thiz,
+ jint fd);
+
+// Returns the version stored in the SmartSelection model file open at |fd|,
+// or -1 if it cannot be determined.
+JNIEXPORT jint JNICALL
+Java_android_view_textclassifier_SmartSelection_nativeGetVersion(JNIEnv* env,
+ jobject thiz,
+ jint fd);
+
// LangId.
JNIEXPORT jlong JNICALL Java_android_view_textclassifier_LangId_nativeNew(
JNIEnv* env, jobject thiz, jint fd);
@@ -65,15 +75,44 @@
#endif
using libtextclassifier::TextClassificationModel;
+using libtextclassifier::ModelOptions;
using libtextclassifier::nlp_core::lang_id::LangId;
namespace {
-std::string ToStlString(JNIEnv* env, jstring str) {
- const char* bytes = env->GetStringUTFChars(str, 0);
- const std::string s = bytes;
- env->ReleaseStringUTFChars(str, bytes);
- return s;
+// Converts |jstr| to a std::string holding its standard UTF-8 encoding.
+//
+// The bytes come from Java's String.getBytes("UTF-8") rather than
+// GetStringUTFChars(), which yields JNI's *modified* UTF-8.
+//
+// Returns true on success. Returns false and leaves |*result| empty if |jstr|
+// is null or any JNI step fails. A Java exception raised along the way is
+// cleared, so callers may keep issuing JNI calls.
+bool JStringToUtf8String(JNIEnv* env, const jstring& jstr,
+                         std::string* result) {
+  result->clear();
+  if (jstr == nullptr) {
+    return false;
+  }
+
+  // Each step runs only if the previous one produced a value. A null result
+  // from any of these calls means a Java exception is now pending.
+  jclass string_class = env->FindClass("java/lang/String");
+  jmethodID get_bytes_id =
+      string_class != nullptr
+          ? env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B")
+          : nullptr;
+  jstring encoding =
+      get_bytes_id != nullptr ? env->NewStringUTF("UTF-8") : nullptr;
+  jbyteArray array = nullptr;
+  if (encoding != nullptr) {
+    array = reinterpret_cast<jbyteArray>(
+        env->CallObjectMethod(jstr, get_bytes_id, encoding));
+  }
+
+  bool ok = false;
+  if (array != nullptr && !env->ExceptionCheck()) {
+    const jsize length = env->GetArrayLength(array);
+    result->resize(static_cast<size_t>(length));
+    if (length > 0) {
+      // Copy straight into the string's buffer; no Release call is needed.
+      env->GetByteArrayRegion(array, 0, length,
+                              reinterpret_cast<jbyte*>(&(*result)[0]));
+    }
+    ok = true;
+  }
+
+  if (env->ExceptionCheck()) {
+    env->ExceptionClear();
+    result->clear();
+    ok = false;
+  }
+
+  if (array != nullptr) env->DeleteLocalRef(array);
+  if (encoding != nullptr) env->DeleteLocalRef(encoding);
+  if (string_class != nullptr) env->DeleteLocalRef(string_class);
+  return ok;
+}
+
+// Returns the UTF-8 contents of |str|, or an empty string when |str| is null.
+std::string ToStlString(JNIEnv* env, const jstring& str) {
+  std::string utf8_text;
+  // The success flag is deliberately dropped: callers only want the text.
+  static_cast<void>(JStringToUtf8String(env, str, &utf8_text));
+  return utf8_text;
}
jobjectArray ScoredStringsToJObjectArray(
@@ -154,6 +193,30 @@
return reinterpret_cast<jlong>(new LangId(fd));
}
+// Returns the language recorded in the selection model's ModelOptions, or the
+// sentinel "UNK" when the options cannot be read or have no language set.
+// Without the has_language() check, options lacking the field would surface
+// as "" instead of the sentinel.
+JNIEXPORT jstring JNICALL
+Java_android_view_textclassifier_SmartSelection_nativeGetLanguage(JNIEnv* env,
+                                                                   jobject clazz,
+                                                                   jint fd) {
+  ModelOptions model_options;
+  if (ReadSelectionModelOptions(fd, &model_options) &&
+      model_options.has_language()) {
+    return env->NewStringUTF(model_options.language().c_str());
+  }
+  return env->NewStringUTF("UNK");
+}
+
+// Returns the version recorded in the selection model's ModelOptions, or -1
+// when the options cannot be read or have no version set. Without the
+// has_version() check, options lacking the field would surface as 0 instead
+// of the sentinel.
+JNIEXPORT jint JNICALL
+Java_android_view_textclassifier_SmartSelection_nativeGetVersion(JNIEnv* env,
+                                                                  jobject clazz,
+                                                                  jint fd) {
+  ModelOptions model_options;
+  if (ReadSelectionModelOptions(fd, &model_options) &&
+      model_options.has_version()) {
+    return model_options.version();
+  }
+  return -1;
+}
+
JNIEXPORT jobjectArray JNICALL
Java_android_view_textclassifier_LangId_nativeFindLanguages(JNIEnv* env,
jobject thiz,