Sync with Google3. am: d2c364b75f
am: 0206f5b9cd

Change-Id: I22d7dff37de9ffac4424b821f4280bed9a265443
diff --git a/models/textclassifier.smartselection.en.model b/models/textclassifier.smartselection.en.model
index 9b35fdd..850033a 100644
--- a/models/textclassifier.smartselection.en.model
+++ b/models/textclassifier.smartselection.en.model
Binary files differ
diff --git a/smartselect/text-classification-model.cc b/smartselect/text-classification-model.cc
index 71457bc..18a15eb 100644
--- a/smartselect/text-classification-model.cc
+++ b/smartselect/text-classification-model.cc
@@ -132,6 +132,23 @@
   };
 }
 
+void ParseMergedModel(const MmapHandle& mmap_handle,
+                      const char** selection_model, int* selection_model_length,
+                      const char** sharing_model, int* sharing_model_length) {
+  // Layout read: [32-bit LE size][selection model][32-bit LE size][sharing model].
+  const char* model_data = reinterpret_cast<const char*>(mmap_handle.start());
+  *selection_model_length =
+      LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
+  model_data += sizeof(*selection_model_length);  // Skip the size field (sizeof(int)).
+  *selection_model = model_data;
+  model_data += *selection_model_length;  // Skip the selection model bytes.
+  // NOTE(review): neither size is validated against the extent of the mapping.
+  *sharing_model_length =
+      LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
+  model_data += sizeof(*sharing_model_length);
+  *sharing_model = model_data;
+}
+
 }  // namespace
 
 bool TextClassificationModel::LoadModels(int fd) {
@@ -140,14 +157,13 @@
     return false;
   }
 
-  // Read the length of the selection model.
-  const char* model_data = reinterpret_cast<const char*>(mmap_handle.start());
-  uint32 selection_model_length =
-      LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
-  model_data += sizeof(selection_model_length);
+  const char *selection_model, *sharing_model;
+  int selection_model_length, sharing_model_length;
+  ParseMergedModel(mmap_handle, &selection_model, &selection_model_length,
+                   &sharing_model, &sharing_model_length);
 
   selection_params_.reset(
-      ModelParamsBuilder(model_data, selection_model_length, nullptr));
+      ModelParamsBuilder(selection_model, selection_model_length, nullptr));
   if (!selection_params_.get()) {
     return false;
   }
@@ -157,12 +173,8 @@
   selection_feature_fn_ = CreateFeatureVectorFn(
       *selection_network_, selection_network_->EmbeddingSize(0));
 
-  model_data += selection_model_length;
-  uint32 sharing_model_length =
-      LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
-  model_data += sizeof(sharing_model_length);
   sharing_params_.reset(
-      ModelParamsBuilder(model_data, sharing_model_length,
+      ModelParamsBuilder(sharing_model, sharing_model_length,
                          selection_params_->GetEmbeddingParams()));
   if (!sharing_params_.get()) {
     return false;
@@ -176,6 +188,31 @@
   return true;
 }
 
+bool ReadSelectionModelOptions(int fd, ModelOptions* model_options) {
+  MmapHandle mmap_handle = MmapFile(fd);  // NOTE(review): not explicitly unmapped here.
+  if (!mmap_handle.ok()) {
+    TC_LOG(ERROR) << "Can't mmap.";
+    return false;
+  }
+  // Locate both sub-models in the merged image; only the selection one is used.
+  const char *selection_model, *sharing_model;
+  int selection_model_length, sharing_model_length;
+  ParseMergedModel(mmap_handle, &selection_model, &selection_model_length,
+                   &sharing_model, &sharing_model_length);
+  // Hand the selection model bytes to a reader exposing its proto.
+  MemoryImageReader<EmbeddingNetworkProto> reader(selection_model,
+                                                  selection_model_length);
+  // Copy the ModelOptions extension out if present; otherwise report failure.
+  auto model_options_extension_id = model_options_in_embedding_network_proto;
+  if (reader.trimmed_proto().HasExtension(model_options_extension_id)) {
+    *model_options =
+        reader.trimmed_proto().GetExtension(model_options_extension_id);
+    return true;
+  } else {
+    return false;  // *model_options is left untouched.
+  }
+}
+
 EmbeddingNetwork::Vector TextClassificationModel::InferInternal(
     const std::string& context, CodepointSpan span,
     const FeatureProcessor& feature_processor, const EmbeddingNetwork& network,
diff --git a/smartselect/text-classification-model.h b/smartselect/text-classification-model.h
index ae2049b..1bc6640 100644
--- a/smartselect/text-classification-model.h
+++ b/smartselect/text-classification-model.h
@@ -114,6 +114,10 @@
   std::set<int> punctuation_to_strip_;
 };
 
+// Mmaps the merged model image behind `fd` and copies the ModelOptions stored in
+// the selection model into *model_options. False if mmap fails or none exist.
+bool ReadSelectionModelOptions(int fd, ModelOptions* model_options);
+
 }  // namespace libtextclassifier
 
 #endif  // LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto
index 4e21af4..51d28c8 100644
--- a/smartselect/text-classification-model.proto
+++ b/smartselect/text-classification-model.proto
@@ -27,6 +27,12 @@
   // If true, will use embeddings from a different model. This is mainly useful
   // for the Sharing model using the embeddings from the Selection model.
   optional bool use_shared_embeddings = 1;
+
+  // Language of the model.
+  optional string language = 2;
+
+  // Version of the model.
+  optional int32 version = 3;
 }
 
 message SelectionModelOptions {
@@ -104,13 +110,6 @@
   // be mapped to an id.
   optional int32 default_collection = 10 [default = -1];
 
-  // Probability with which to drop context of examples.
-  optional float context_dropout_probability = 11 [default = 0.0];
-
-  // If true, drop variable amounts of context, if false all context, with
-  // probability given by context_dropout_ratio.
-  optional bool use_variable_context_dropout = 12 [default = false];
-
   // If true, will split the input by lines, and only use the line that contains
   // the clicked token.
   optional bool only_use_line_with_click = 13 [default = false];
@@ -137,9 +136,6 @@
   }
   optional CenterTokenSelectionMethod center_token_selection_method = 16;
 
-  // If true, during training will click random token in the example selection.
-  optional bool click_random_token_in_selection = 17 [default = true];
-
   // If true, span boundaries will be snapped to containing tokens and not
   // required to exactly match token boundaries.
   optional bool snap_label_span_boundaries_to_containing_tokens = 18;
@@ -167,27 +163,6 @@
   //      to it. So the resulting feature vector has two regions.
   optional int32 feature_version = 25 [default = 0];
 
-  // These settings control whether a distortion is applied to part of the data,
-  // for Smart Sharing. Distortion means modifying (expanding) the bounds of the
-  // selection and changing the example's collection to "other". The goal is to
-  // expose the model to overselections as negative examples.
-  // If true, distortion is applied. Otherwise the other settings are ignored.
-  optional bool distortion_enable = 26;
-  // Probability settings. They individual values and their sum should be in the
-  // range [0, 1]. They specify the probability of the distorition being applied
-  // to just one of the bounds (left or right with equal probabiolity), both
-  // bounds, or not at all (the remaining probability).
-  // If the context does not contain tokens on the given side of the selection,
-  // the probabilistic decision is ignored. This means that the actual frequency
-  // of distortion is somewhat lower than specified here.
-  optional double distortion_probability_one_side = 27;
-  optional double distortion_probability_both_sides = 28;
-  // The maximum number of tokens to include (on one side) when distortion is
-  // applied. The actual number is selected (independently for each side)
-  // uniformly from integers from 1 to this value, inclusive. If the context is
-  // too short, the end result is truncated.
-  optional double distortion_max_num_tokens = 29;
-
   // Controls the type of tokenization the model will use for the input text.
   enum TokenizationType {
     INVALID_TOKENIZATION_TYPE = 0;
@@ -202,7 +177,7 @@
       [default = INTERNAL_TOKENIZER];
   optional bool icu_preserve_whitespace_tokens = 31 [default = false];
 
-  reserved 7;
+  reserved 7, 11, 12, 17, 26, 27, 28, 29, 32;
 };
 
 extend nlp_core.EmbeddingNetworkProto {
diff --git a/tests/testdata/smartselection.model b/tests/testdata/smartselection.model
index 4f133f0..850033a 100644
--- a/tests/testdata/smartselection.model
+++ b/tests/testdata/smartselection.model
Binary files differ
diff --git a/tests/text-classification-model_test.cc b/tests/text-classification-model_test.cc
index c52ad38..10da631 100644
--- a/tests/text-classification-model_test.cc
+++ b/tests/text-classification-model_test.cc
@@ -31,6 +31,17 @@
   return TEST_DATA_DIR "smartselection.model";
 }
 
+TEST(TextClassificationModelTest, ReadModelOptions) {
+  const std::string model_path = GetModelPath();
+  int fd = open(model_path.c_str(), O_RDONLY);  // NOTE(review): fd is not checked.
+  ModelOptions model_options;
+  ASSERT_TRUE(ReadSelectionModelOptions(fd, &model_options));
+  close(fd);  // NOTE(review): not reached if the ASSERT above fails.
+  // The test model is expected to report language "en" and a positive version.
+  EXPECT_EQ("en", model_options.language());
+  EXPECT_GT(model_options.version(), 0);
+}
+
 TEST(TextClassificationModelTest, SuggestSelection) {
   const std::string model_path = GetModelPath();
   int fd = open(model_path.c_str(), O_RDONLY);
diff --git a/textclassifier_jni.cc b/textclassifier_jni.cc
index 66e2934..4ce08a0 100644
--- a/textclassifier_jni.cc
+++ b/textclassifier_jni.cc
@@ -47,6 +47,16 @@
                                                             jobject thiz,
                                                             jlong ptr);
 
+JNIEXPORT jstring JNICALL
+Java_android_view_textclassifier_SmartSelection_nativeGetLanguage(JNIEnv* env,
+                                                                  jobject thiz,
+                                                                  jint fd);
+
+JNIEXPORT jint JNICALL
+Java_android_view_textclassifier_SmartSelection_nativeGetVersion(JNIEnv* env,
+                                                                 jobject thiz,
+                                                                 jint fd);
+
 // LangId.
 JNIEXPORT jlong JNICALL Java_android_view_textclassifier_LangId_nativeNew(
     JNIEnv* env, jobject thiz, jint fd);
@@ -65,15 +75,44 @@
 #endif
 
 using libtextclassifier::TextClassificationModel;
+using libtextclassifier::ModelOptions;
 using libtextclassifier::nlp_core::lang_id::LangId;
 
 namespace {
 
-std::string ToStlString(JNIEnv* env, jstring str) {
-  const char* bytes = env->GetStringUTFChars(str, 0);
-  const std::string s = bytes;
-  env->ReleaseStringUTFChars(str, bytes);
-  return s;
+bool JStringToUtf8String(JNIEnv* env, const jstring& jstr,
+                         std::string* result) {
+  if (jstr == nullptr) {  // Null input: clear *result and report failure.
+    *result = std::string();
+    return false;
+  }
+  // Look up String.getBytes(String charsetName) so Java performs the encoding.
+  jclass string_class = env->FindClass("java/lang/String");
+  jmethodID get_bytes_id =
+      env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
+  // Invoke jstr.getBytes("UTF-8").
+  jstring encoding = env->NewStringUTF("UTF-8");
+  jbyteArray array = reinterpret_cast<jbyteArray>(
+      env->CallObjectMethod(jstr, get_bytes_id, encoding));
+  // NOTE(review): no check for a null array or a pending exception before use.
+  jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE);
+  int length = env->GetArrayLength(array);
+  // NOTE(review): JNI_FALSE above is passed where a jboolean* isCopy is expected.
+  *result = std::string(reinterpret_cast<char*>(array_bytes), length);
+
+  // JNI_ABORT: release the elements without copy-back, then drop the local refs.
+  env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT);
+  env->DeleteLocalRef(array);
+  env->DeleteLocalRef(string_class);
+  env->DeleteLocalRef(encoding);
+
+  return true;
+}
+
+std::string ToStlString(JNIEnv* env, const jstring& str) {
+  std::string result;  // Stays empty when str is null.
+  JStringToUtf8String(env, str, &result);  // Success flag is deliberately ignored.
+  return result;
+}
 
 jobjectArray ScoredStringsToJObjectArray(
@@ -154,6 +193,30 @@
   return reinterpret_cast<jlong>(new LangId(fd));
 }
 
+JNIEXPORT jstring JNICALL
+Java_android_view_textclassifier_SmartSelection_nativeGetLanguage(JNIEnv* env,
+                                                                  jobject clazz,
+                                                                  jint fd) {
+  ModelOptions model_options;  // Filled from the model image behind fd.
+  if (ReadSelectionModelOptions(fd, &model_options)) {
+    return env->NewStringUTF(model_options.language().c_str());
+  } else {
+    return env->NewStringUTF("UNK");  // Options could not be read.
+  }
+}
+
+JNIEXPORT jint JNICALL
+Java_android_view_textclassifier_SmartSelection_nativeGetVersion(JNIEnv* env,
+                                                                 jobject clazz,
+                                                                 jint fd) {
+  ModelOptions model_options;  // Filled from the model image behind fd.
+  if (ReadSelectionModelOptions(fd, &model_options)) {
+    return model_options.version();
+  } else {
+    return -1;  // Options could not be read.
+  }
+}
+
 JNIEXPORT jobjectArray JNICALL
 Java_android_view_textclassifier_LangId_nativeFindLanguages(JNIEnv* env,
                                                             jobject thiz,