Snap for 8152310 from faf03992dc4e169d214b17726a82f664efd6b57a to mainline-media-swcodec-release

Change-Id: I0ba3ed5e29e26af2034e05fc9e46dac96bfa4225
diff --git a/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java b/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java
index dae0442..67e300d 100644
--- a/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java
+++ b/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java
@@ -66,8 +66,8 @@
   }
 
   /** Returns if the result id was generated from the default text classifier. */
-  public static boolean isFromDefaultTextClassifier(String resultId) {
-    return resultId.startsWith(CLASSIFIER_ID + '|');
+  public static boolean isFromDefaultTextClassifier(@Nullable String resultId) {
+    return resultId != null && resultId.startsWith(CLASSIFIER_ID + '|');
   }
 
   /** Returns all the model names encoded in the signature. */
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
index 0e3842c..71f9a4f 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
@@ -52,16 +52,21 @@
 import java.util.stream.Collectors;
 import org.junit.After;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.Mockito;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
 
 @SmallTest
 @RunWith(AndroidJUnit4.class)
 public class DefaultTextClassifierServiceTest {
+
+  @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
   /** A statsd config ID, which is arbitrary. */
   private static final long CONFIG_ID = 689777;
 
@@ -79,7 +84,6 @@
 
   @Before
   public void setup() {
-    MockitoAnnotations.initMocks(this);
 
     testInjector = new TestInjector(ApplicationProvider.getApplicationContext());
     defaultTextClassifierService = new DefaultTextClassifierService(testInjector);
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
index 5297640..20ae592 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
@@ -53,7 +53,8 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
 
 @SmallTest
 @RunWith(AndroidJUnit4.class)
@@ -67,6 +68,7 @@
   @Mock private DownloadedModelManager downloadedModelManager;
 
   @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
+  @Rule public final MockitoRule mocks = MockitoJUnit.rule();
 
   private File rootTestDir;
   private ModelFileManagerImpl modelFileManager;
@@ -75,7 +77,6 @@
 
   @Before
   public void setup() {
-    MockitoAnnotations.initMocks(this);
     deviceConfig = new TestingDeviceConfig();
     rootTestDir =
         new File(ApplicationProvider.getApplicationContext().getCacheDir(), "rootTestDir");
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
index 216cd5d..3aab211 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
@@ -27,14 +27,18 @@
 import com.google.android.textclassifier.RemoteActionTemplate;
 import java.util.List;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
 
 @SmallTest
 @RunWith(AndroidJUnit4.class)
 public class TemplateIntentFactoryTest {
 
+  @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
   private static final String TITLE_WITHOUT_ENTITY = "Map";
   private static final String TITLE_WITH_ENTITY = "Map NW14D1";
   private static final String DESCRIPTION = "Check the map";
@@ -71,7 +75,6 @@
 
   @Before
   public void setup() {
-    MockitoAnnotations.initMocks(this);
     templateIntentFactory = new TemplateIntentFactory();
   }
 
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
index ffd2ee4..3a8fefc 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
@@ -86,8 +86,8 @@
     return ImmutableList.copyOf(
         metricsList.stream()
             .flatMap(statsLogReport -> statsLogReport.getEventMetrics().getDataList().stream())
-            .flatMap(eventMetricData -> backfillAggregatedAtomsinEventMetric(
-                    eventMetricData).stream())
+            .flatMap(
+                eventMetricData -> backfillAggregatedAtomsinEventMetric(eventMetricData).stream())
             .sorted(Comparator.comparing(EventMetricData::getElapsedTimestampNanos))
             .map(EventMetricData::getAtom)
             .collect(Collectors.toList()));
@@ -136,7 +136,7 @@
   }
 
   private static ImmutableList<EventMetricData> backfillAggregatedAtomsinEventMetric(
-    EventMetricData metricData) {
+      EventMetricData metricData) {
     if (metricData.hasAtom()) {
       return ImmutableList.of(metricData);
     }
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
index c626ed7..394b7ad 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
@@ -46,7 +46,8 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
 
 @RunWith(AndroidJUnit4.class)
 public final class ModelDownloadManagerTest {
@@ -61,6 +62,8 @@
   public final TextClassifierDownloadLoggerTestRule loggerTestRule =
       new TextClassifierDownloadLoggerTestRule();
 
+  @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
   private TestingDeviceConfig deviceConfig;
   private WorkManager workManager;
   private ModelDownloadManager downloadManager;
@@ -68,7 +71,6 @@
 
   @Before
   public void setUp() {
-    MockitoAnnotations.initMocks(this);
     Context context = ApplicationProvider.getApplicationContext();
     WorkManagerTestInitHelper.initializeTestWorkManager(context);
 
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java
index eac2af3..76d04e0 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java
@@ -37,14 +37,19 @@
 import java.io.File;
 import java.net.URI;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
 
 @RunWith(JUnit4.class)
 public final class ModelDownloaderServiceImplTest {
+
+  @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
   private static final long BYTES_WRITTEN = 1L;
   private static final String DOWNLOAD_URI =
       "https://www.gstatic.com/android/text_classifier/r/v999/en.fb";
@@ -66,7 +71,6 @@
 
   @Before
   public void setUp() {
-    MockitoAnnotations.initMocks(this);
 
     this.targetModelFile =
         new File(ApplicationProvider.getApplicationContext().getCacheDir(), "model.fb");
diff --git a/native/actions/actions-entity-data.bfbs b/native/actions/actions-entity-data.bfbs
index 7421579..6ebf1cf 100644
--- a/native/actions/actions-entity-data.bfbs
+++ b/native/actions/actions-entity-data.bfbs
Binary files differ
diff --git a/native/actions/actions-entity-data.fbs b/native/actions/actions-entity-data.fbs
index 21584b6..e906f93 100644
--- a/native/actions/actions-entity-data.fbs
+++ b/native/actions/actions-entity-data.fbs
@@ -18,7 +18,7 @@
 namespace libtextclassifier3;
 table ActionsEntityData {
   // Extracted text.
-  text:string (shared);
+  text:string (key, shared);
 }
 
 root_type libtextclassifier3.ActionsEntityData;
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index b1a042c..9f9a8d4 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -17,6 +17,7 @@
 #include "actions/actions-suggestions.h"
 
 #include <memory>
+#include <string>
 #include <vector>
 
 #include "utils/base/statusor.h"
@@ -40,6 +41,7 @@
 #include "utils/strings/stringpiece.h"
 #include "utils/strings/utf8.h"
 #include "utils/utf8/unicodetext.h"
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/lite/string_util.h"
 
 namespace libtextclassifier3 {
@@ -809,12 +811,14 @@
 
 void ActionsSuggestions::PopulateTextReplies(
     const tflite::Interpreter* interpreter, int suggestion_index,
-    int score_index, const std::string& type,
+    int score_index, const std::string& type, float priority_score,
+    const absl::flat_hash_set<std::string>& blocklist,
     ActionsSuggestionsResponse* response) const {
   const std::vector<tflite::StringRef> replies =
       model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
   const TensorView<float> scores =
       model_executor_->OutputView<float>(score_index, interpreter);
+
   for (int i = 0; i < replies.size(); i++) {
     if (replies[i].len == 0) {
       continue;
@@ -823,8 +827,12 @@
     if (score < preconditions_.min_reply_score_threshold) {
       continue;
     }
-    response->actions.push_back(
-        {std::string(replies[i].str, replies[i].len), type, score});
+    std::string response_text(replies[i].str, replies[i].len);
+    if (blocklist.contains(response_text)) {
+      continue;
+    }
+
+    response->actions.push_back({response_text, type, score, priority_score});
   }
 }
 
@@ -909,10 +917,12 @@
   // Read smart reply predictions.
   if (!response->output_filtered_min_triggering_score &&
       model_->tflite_model_spec()->output_replies() >= 0) {
+    absl::flat_hash_set<std::string> empty_blocklist;
     PopulateTextReplies(interpreter,
                         model_->tflite_model_spec()->output_replies(),
                         model_->tflite_model_spec()->output_replies_scores(),
-                        model_->smart_reply_action_type()->str(), response);
+                        model_->smart_reply_action_type()->str(),
+                        /* priority_score */ 0.0, empty_blocklist, response);
   }
 
   // Read actions suggestions.
@@ -950,17 +960,26 @@
       const int suggestions_index = metadata->output_suggestions();
       const int suggestions_scores_index =
           metadata->output_suggestions_scores();
+      absl::flat_hash_set<std::string> response_text_blocklist;
       switch (metadata->prediction_type()) {
         case PredictionType_NEXT_MESSAGE_PREDICTION:
           if (!task_spec || task_spec->type()->size() == 0) {
             TC3_LOG(WARNING) << "Task type not provided, use default "
                                 "smart_reply_action_type!";
           }
+          if (task_spec) {
+            if (task_spec->response_text_blocklist()) {
+              for (const auto& val : *task_spec->response_text_blocklist()) {
+                response_text_blocklist.insert(val->str());
+              }
+            }
+          }
           PopulateTextReplies(
               interpreter, suggestions_index, suggestions_scores_index,
               task_spec ? task_spec->type()->str()
                         : model_->smart_reply_action_type()->str(),
-              response);
+              task_spec ? task_spec->priority_score() : 0.0,
+              response_text_blocklist, response);
           break;
         case PredictionType_INTENT_TRIGGERING:
           PopulateIntentTriggering(interpreter, suggestions_index,
diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h
index 32edc78..87f55fb 100644
--- a/native/actions/actions-suggestions.h
+++ b/native/actions/actions-suggestions.h
@@ -43,6 +43,7 @@
 #include "utils/utf8/unilib.h"
 #include "utils/variant.h"
 #include "utils/zlib/zlib.h"
+#include "absl/container/flat_hash_set.h"
 
 namespace libtextclassifier3 {
 
@@ -176,7 +177,8 @@
 
   void PopulateTextReplies(const tflite::Interpreter* interpreter,
                            int suggestion_index, int score_index,
-                           const std::string& type,
+                           const std::string& type, float priority_score,
+                           const absl::flat_hash_set<std::string>& blocklist,
                            ActionsSuggestionsResponse* response) const;
 
   void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
index 062d527..b51ebc7 100644
--- a/native/actions/actions-suggestions_test.cc
+++ b/native/actions/actions-suggestions_test.cc
@@ -1798,6 +1798,7 @@
 TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) {
   std::unique_ptr<ActionsSuggestions> actions_suggestions =
       LoadTestModel(kMultiTaskSrEmojiModelFileName);
+
   const ActionsSuggestionsResponse response =
       actions_suggestions->SuggestActions(
           {{{/*user_id=*/1, "hello?",
@@ -1807,9 +1808,31 @@
              /*locales=*/"en"}}});
   EXPECT_EQ(response.actions.size(), 5);
   EXPECT_EQ(response.actions[0].response_text, "😁");
-  EXPECT_EQ(response.actions[0].type, "EMOJI_CONCEPT");
-  EXPECT_EQ(response.actions[1].response_text, "Yes");
-  EXPECT_EQ(response.actions[1].type, "REPLY_SUGGESTION");
+  EXPECT_EQ(response.actions[0].type, "text_reply");
+  EXPECT_EQ(response.actions[1].response_text, "👋");
+  EXPECT_EQ(response.actions[1].type, "text_reply");
+  EXPECT_EQ(response.actions[2].response_text, "Yes");
+  EXPECT_EQ(response.actions[2].type, "text_reply");
+}
+
+TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelRemovesTextHeadEmoji) {
+  std::unique_ptr<ActionsSuggestions> actions_suggestions =
+      LoadTestModel(kMultiTaskSrEmojiModelFileName);
+
+  const ActionsSuggestionsResponse response =
+      actions_suggestions->SuggestActions(
+          {{{/*user_id=*/1, "a pleasure chatting",
+             /*reference_time_ms_utc=*/0,
+             /*reference_timezone=*/"Europe/Zurich",
+             /*annotations=*/{},
+             /*locales=*/"en"}}});
+  EXPECT_EQ(response.actions.size(), 3);
+  EXPECT_EQ(response.actions[0].response_text, "😁");
+  EXPECT_EQ(response.actions[0].type, "text_reply");
+  EXPECT_EQ(response.actions[1].response_text, "😘");
+  EXPECT_EQ(response.actions[1].type, "text_reply");
+  EXPECT_EQ(response.actions[2].response_text, "Okay");
+  EXPECT_EQ(response.actions[2].type, "text_reply");
 }
 
 TEST_F(ActionsSuggestionsTest, LiveRelayModel) {
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
index 8c03eeb..0d8c7ad 100644
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
@@ -36,6 +36,17 @@
   ENTITY_ANNOTATION = 3,
 }
 
+namespace libtextclassifier3;
+enum RankingOptionsSortType : int {
+  SORT_TYPE_UNSPECIFIED = 0,
+
+  // Rank results (or groups) by score, then type
+  SORT_TYPE_SCORE = 1,
+
+  // Rank results (or groups) by priority score, then score, then type
+  SORT_TYPE_PRIORITY_SCORE = 2,
+}
+
 // Prediction metadata for an arbitrary task.
 namespace libtextclassifier3;
 table PredictionMetadata {
@@ -315,10 +326,11 @@
   // Additional entity information.
   serialized_entity_data:string (shared);
 
-  // Priority score used for internal conflict resolution.
+  // For ranking and internal conflict resolution.
   priority_score:float = 0;
 
   entity_data:ActionsEntityData;
+  response_text_blocklist:[string];
 }
 
 // Options to specify triggering behaviour per action class.
@@ -416,6 +428,8 @@
 
   // If true, keep actions from the same entities together for ranking.
   group_by_annotations:bool = true;
+
+  sort_type:RankingOptionsSortType = SORT_TYPE_SCORE;
 }
 
 // Entity data to set from capturing groups.
diff --git a/native/actions/ranker.cc b/native/actions/ranker.cc
index d52ecaa..46e392a 100644
--- a/native/actions/ranker.cc
+++ b/native/actions/ranker.cc
@@ -20,6 +20,8 @@
 #include <set>
 #include <vector>
 
+#include "actions/actions_model_generated.h"
+
 #if !defined(TC3_DISABLE_LUA)
 #include "actions/lua-ranker.h"
 #endif
@@ -34,11 +36,22 @@
 namespace {
 
 void SortByScoreAndType(std::vector<ActionSuggestion>* actions) {
-  std::sort(actions->begin(), actions->end(),
-            [](const ActionSuggestion& a, const ActionSuggestion& b) {
-              return a.score > b.score ||
-                     (a.score >= b.score && a.type < b.type);
-            });
+  std::stable_sort(actions->begin(), actions->end(),
+                   [](const ActionSuggestion& a, const ActionSuggestion& b) {
+                     return a.score > b.score ||
+                            (a.score >= b.score && a.type < b.type);
+                   });
+}
+
+void SortByPriorityAndScoreAndType(std::vector<ActionSuggestion>* actions) {
+  std::stable_sort(
+      actions->begin(), actions->end(),
+      [](const ActionSuggestion& a, const ActionSuggestion& b) {
+        return a.priority_score > b.priority_score ||
+               (a.priority_score >= b.priority_score && a.score > b.score) ||
+               (a.priority_score >= b.priority_score && a.score >= b.score &&
+                a.type < b.type);
+      });
 }
 
 template <typename T>
@@ -241,13 +254,8 @@
     const reflection::Schema* annotations_entity_data_schema) const {
   if (options_->deduplicate_suggestions() ||
       options_->deduplicate_suggestions_by_span()) {
-    // First order suggestions by priority score for deduplication.
-    std::sort(
-        response->actions.begin(), response->actions.end(),
-        [](const ActionSuggestion& a, const ActionSuggestion& b) {
-          return a.priority_score > b.priority_score ||
-                 (a.priority_score >= b.priority_score && a.score > b.score);
-        });
+    // Order suggestions by [priority score -> score] for deduplication
+    SortByPriorityAndScoreAndType(&response->actions);
 
     // Deduplicate, keeping the higher score actions.
     if (options_->deduplicate_suggestions()) {
@@ -275,6 +283,8 @@
     }
   }
 
+  bool sort_by_priority =
+      options_->sort_type() == RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
   // Suppress smart replies if actions are present.
   if (options_->suppress_smart_replies_with_actions()) {
     std::vector<ActionSuggestion> non_smart_reply_actions;
@@ -316,17 +326,35 @@
 
     // Sort within each group by score.
     for (std::vector<ActionSuggestion>& group : groups) {
-      SortByScoreAndType(&group);
+      if (sort_by_priority) {
+        SortByPriorityAndScoreAndType(&group);
+      } else {
+        SortByScoreAndType(&group);
+      }
     }
 
-    // Sort groups by maximum score.
-    std::sort(groups.begin(), groups.end(),
-              [](const std::vector<ActionSuggestion>& a,
-                 const std::vector<ActionSuggestion>& b) {
-                return a.begin()->score > b.begin()->score ||
-                       (a.begin()->score >= b.begin()->score &&
-                        a.begin()->type < b.begin()->type);
-              });
+    // Sort groups by maximum score or priority score.
+    if (sort_by_priority) {
+      std::stable_sort(
+          groups.begin(), groups.end(),
+          [](const std::vector<ActionSuggestion>& a,
+             const std::vector<ActionSuggestion>& b) {
+            return (a.begin()->priority_score > b.begin()->priority_score) ||
+                   (a.begin()->priority_score >= b.begin()->priority_score &&
+                    a.begin()->score > b.begin()->score) ||
+                   (a.begin()->priority_score >= b.begin()->priority_score &&
+                    a.begin()->score >= b.begin()->score &&
+                    a.begin()->type < b.begin()->type);
+          });
+    } else {
+      std::stable_sort(groups.begin(), groups.end(),
+                       [](const std::vector<ActionSuggestion>& a,
+                          const std::vector<ActionSuggestion>& b) {
+                         return a.begin()->score > b.begin()->score ||
+                                (a.begin()->score >= b.begin()->score &&
+                                 a.begin()->type < b.begin()->type);
+                       });
+    }
 
     // Flatten result.
     const size_t num_actions = response->actions.size();
@@ -336,9 +364,9 @@
       response->actions.insert(response->actions.end(), actions.begin(),
                                actions.end());
     }
-
+  } else if (sort_by_priority) {
+    SortByPriorityAndScoreAndType(&response->actions);
   } else {
-    // Order suggestions independently by score.
     SortByScoreAndType(&response->actions);
   }
 
diff --git a/native/actions/ranker_test.cc b/native/actions/ranker_test.cc
index b52cf45..5eba45f 100644
--- a/native/actions/ranker_test.cc
+++ b/native/actions/ranker_test.cc
@@ -18,6 +18,7 @@
 
 #include <string>
 
+#include "actions/actions_model_generated.h"
 #include "actions/types.h"
 #include "utils/zlib/zlib.h"
 #include "gmock/gmock.h"
@@ -308,12 +309,12 @@
     response.actions.push_back({/*response_text=*/"",
                                 /*type=*/"call_phone",
                                 /*score=*/1.0,
-                                /*priority_score=*/1.0,
+                                /*priority_score=*/0.0,
                                 /*annotations=*/{annotation}});
     response.actions.push_back({/*response_text=*/"",
                                 /*type=*/"add_contact",
                                 /*score=*/0.0,
-                                /*priority_score=*/0.0,
+                                /*priority_score=*/1.0,
                                 /*annotations=*/{annotation}});
   }
   response.actions.push_back({/*response_text=*/"How are you?",
@@ -338,6 +339,58 @@
                                  IsAction("text_reply", "How are you?", 0.5)}));
 }
 
+TEST(RankingTest, GroupsByAnnotationsSortedByPriority) {
+  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
+  ActionsSuggestionsResponse response;
+  response.actions.push_back({/*response_text=*/"How are you?",
+                              /*type=*/"text_reply",
+                              /*score=*/2.0,
+                              /*priority_score=*/0.0});
+  {
+    ActionSuggestionAnnotation annotation;
+    annotation.span = {/*message_index=*/0, /*span=*/{5, 8},
+                       /*text=*/"911"};
+    annotation.entity = ClassificationResult("phone", 1.0);
+    response.actions.push_back({/*response_text=*/"",
+                                /*type=*/"add_contact",
+                                /*score=*/0.0,
+                                /*priority_score=*/1.0,
+                                /*annotations=*/{annotation}});
+    response.actions.push_back({/*response_text=*/"",
+                                /*type=*/"call_phone",
+                                /*score=*/1.0,
+                                /*priority_score=*/0.0,
+                                /*annotations=*/{annotation}});
+    response.actions.push_back({/*response_text=*/"",
+                                /*type=*/"add_contact2",
+                                /*score=*/0.5,
+                                /*priority_score=*/1.0,
+                                /*annotations=*/{annotation}});
+  }
+  RankingOptionsT options;
+  options.group_by_annotations = true;
+  options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
+  flatbuffers::FlatBufferBuilder builder;
+  builder.Finish(RankingOptions::Pack(builder, &options));
+  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
+      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
+
+  ranker->RankActions(conversation, &response);
+
+  // The text reply should be last, even though it's score is higher than
+  // any other scores -- because it's priority_score is lower than the max
+  // of those with the 'phone' annotation
+  EXPECT_THAT(response.actions,
+              testing::ElementsAreArray({
+                  // Group 1 (Phone annotation)
+                  IsAction("add_contact2", "", 0.5),  // priority_score=1.0
+                  IsAction("add_contact", "", 0.0),   // priority_score=1.0
+                  IsAction("call_phone", "", 1.0),    // priority_score=0.0
+                  IsAction("text_reply", "How are you?", 2.0),  // Group 2
+              }));
+}
+
 TEST(RankingTest, SortsActionsByScore) {
   const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
   ActionsSuggestionsResponse response;
@@ -349,12 +402,12 @@
     response.actions.push_back({/*response_text=*/"",
                                 /*type=*/"call_phone",
                                 /*score=*/1.0,
-                                /*priority_score=*/1.0,
+                                /*priority_score=*/0.0,
                                 /*annotations=*/{annotation}});
     response.actions.push_back({/*response_text=*/"",
                                 /*type=*/"add_contact",
                                 /*score=*/0.0,
-                                /*priority_score=*/0.0,
+                                /*priority_score=*/1.0,
                                 /*annotations=*/{annotation}});
   }
   response.actions.push_back({/*response_text=*/"How are you?",
@@ -378,5 +431,40 @@
                                  IsAction("add_contact", "", 0.0)}));
 }
 
+TEST(RankingTest, SortsActionsByPriority) {
+  const Conversation conversation = {{{/*user_id=*/1, "hello?"}}};
+  ActionsSuggestionsResponse response;
+  // emoji replies given higher priority_score
+  response.actions.push_back({/*response_text=*/"😁",
+                              /*type=*/"text_reply",
+                              /*score=*/0.5,
+                              /*priority_score=*/1.0});
+  response.actions.push_back({/*response_text=*/"👋",
+                              /*type=*/"text_reply",
+                              /*score=*/0.4,
+                              /*priority_score=*/1.0});
+  response.actions.push_back({/*response_text=*/"Yes",
+                              /*type=*/"text_reply",
+                              /*score=*/1.0,
+                              /*priority_score=*/0.0});
+  RankingOptionsT options;
+  // Don't group by annotation.
+  options.group_by_annotations = false;
+  options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
+  flatbuffers::FlatBufferBuilder builder;
+  builder.Finish(RankingOptions::Pack(builder, &options));
+  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
+      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
+
+  ranker->RankActions(conversation, &response);
+
+  EXPECT_THAT(response.actions, testing::ElementsAreArray(
+                                    {IsAction("text_reply", "😁", 0.5),
+                                     IsAction("text_reply", "👋", 0.4),
+                                     // Ranked last because of priority score
+                                     IsAction("text_reply", "Yes", 1.0)}));
+}
+
 }  // namespace
 }  // namespace libtextclassifier3
diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model
index 77e556c..0fa7f7e 100644
--- a/native/actions/test_data/actions_suggestions_grammar_test.model
+++ b/native/actions/test_data/actions_suggestions_grammar_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model
index c468bd5..6107e98 100644
--- a/native/actions/test_data/actions_suggestions_test.model
+++ b/native/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
index ec421a1..436ed93 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
index 24be6c6..935691d 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
index fd7ddf2..2c9f74b 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
index c969c56..cdb7523 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
index d171898..ac28fa2 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
index 937552b..d864b79 100644
--- a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
+++ b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
Binary files differ
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index 32bd29c..e0d4241 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -973,11 +973,11 @@
   // Sort candidates according to their position in the input, so that the next
   // code can assume that any connected component of overlapping spans forms a
   // contiguous block.
-  std::sort(candidates.annotated_spans[0].begin(),
-            candidates.annotated_spans[0].end(),
-            [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
-              return a.span.first < b.span.first;
-            });
+  std::stable_sort(candidates.annotated_spans[0].begin(),
+                   candidates.annotated_spans[0].end(),
+                   [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
+                     return a.span.first < b.span.first;
+                   });
 
   std::vector<int> candidate_indices;
   if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
@@ -987,13 +987,14 @@
     return original_click_indices;
   }
 
-  std::sort(candidate_indices.begin(), candidate_indices.end(),
-            [this, &candidates](int a, int b) {
-              return GetPriorityScore(
-                         candidates.annotated_spans[0][a].classification) >
-                     GetPriorityScore(
-                         candidates.annotated_spans[0][b].classification);
-            });
+  std::stable_sort(
+      candidate_indices.begin(), candidate_indices.end(),
+      [this, &candidates](int a, int b) {
+        return GetPriorityScore(
+                   candidates.annotated_spans[0][a].classification) >
+               GetPriorityScore(
+                   candidates.annotated_spans[0][b].classification);
+      });
 
   for (const int i : candidate_indices) {
     if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) &&
@@ -1173,7 +1174,7 @@
     }
   }
 
-  std::sort(
+  std::stable_sort(
       conflicting_indices.begin(), conflicting_indices.end(),
       [this, &scores_lengths, candidates, conflicting_indices](int i, int j) {
         if (scores_lengths[i].first == scores_lengths[j].first &&
@@ -1241,7 +1242,7 @@
     chosen_indices_for_source_ptr->insert(considered_candidate);
   }
 
-  std::sort(chosen_indices->begin(), chosen_indices->end());
+  std::stable_sort(chosen_indices->begin(), chosen_indices->end());
 
   return true;
 }
@@ -1414,10 +1415,11 @@
 // Sorts the classification results from high score to low score.
 void SortClassificationResults(
     std::vector<ClassificationResult>* classification_results) {
-  std::sort(classification_results->begin(), classification_results->end(),
-            [](const ClassificationResult& a, const ClassificationResult& b) {
-              return a.score > b.score;
-            });
+  std::stable_sort(
+      classification_results->begin(), classification_results->end(),
+      [](const ClassificationResult& a, const ClassificationResult& b) {
+        return a.score > b.score;
+      });
 }
 }  // namespace
 
@@ -1936,10 +1938,11 @@
   }
 
   // Sort results according to score.
-  std::sort(results.begin(), results.end(),
-            [](const ClassificationResult& a, const ClassificationResult& b) {
-              return a.score > b.score;
-            });
+  std::stable_sort(
+      results.begin(), results.end(),
+      [](const ClassificationResult& a, const ClassificationResult& b) {
+        return a.score > b.score;
+      });
 
   if (results.empty()) {
     results = {{Collections::Other(), 1.0}};
@@ -2297,19 +2300,19 @@
   // Also sort them according to the end position and collection, so that the
   // deduplication code below can assume that same spans and classifications
   // form contiguous blocks.
-  std::sort(candidates->begin(), candidates->end(),
-            [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
-              if (a.span.first != b.span.first) {
-                return a.span.first < b.span.first;
-              }
+  std::stable_sort(candidates->begin(), candidates->end(),
+                   [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
+                     if (a.span.first != b.span.first) {
+                       return a.span.first < b.span.first;
+                     }
 
-              if (a.span.second != b.span.second) {
-                return a.span.second < b.span.second;
-              }
+                     if (a.span.second != b.span.second) {
+                       return a.span.second < b.span.second;
+                     }
 
-              return a.classification[0].collection <
-                     b.classification[0].collection;
-            });
+                     return a.classification[0].collection <
+                            b.classification[0].collection;
+                   });
 
   std::vector<int> candidate_indices;
   if (!ResolveConflicts(*candidates, context, tokens,
@@ -2904,10 +2907,10 @@
       return false;
     }
   }
-  std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
-            [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
-              return lhs.score < rhs.score;
-            });
+  std::stable_sort(scored_chunks.rbegin(), scored_chunks.rend(),
+                   [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
+                     return lhs.score < rhs.score;
+                   });
 
   // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
   // them greedily as long as they do not overlap with any previously picked
@@ -2936,7 +2939,7 @@
     chunks->push_back(scored_chunk.token_span);
   }
 
-  std::sort(chunks->begin(), chunks->end());
+  std::stable_sort(chunks->begin(), chunks->end());
 
   return true;
 }
diff --git a/native/annotator/datetime/datetime-grounder.cc b/native/annotator/datetime/datetime-grounder.cc
index 7d5f440..ff0c775 100644
--- a/native/annotator/datetime/datetime-grounder.cc
+++ b/native/annotator/datetime/datetime-grounder.cc
@@ -16,6 +16,7 @@
 
 #include "annotator/datetime/datetime-grounder.h"
 
+#include <algorithm>
 #include <limits>
 #include <unordered_map>
 #include <vector>
@@ -250,10 +251,10 @@
     }
 
     // Sort the date time units by component type.
-    std::sort(date_components.begin(), date_components.end(),
-              [](DatetimeComponent a, DatetimeComponent b) {
-                return a.component_type > b.component_type;
-              });
+    std::stable_sort(date_components.begin(), date_components.end(),
+                     [](DatetimeComponent a, DatetimeComponent b) {
+                       return a.component_type > b.component_type;
+                     });
     result.datetime_components.swap(date_components);
     datetime_parse_result.push_back(result);
   }
diff --git a/native/annotator/datetime/extractor.cc b/native/annotator/datetime/extractor.cc
index 867c886..94a0961 100644
--- a/native/annotator/datetime/extractor.cc
+++ b/native/annotator/datetime/extractor.cc
@@ -16,6 +16,8 @@
 
 #include "annotator/datetime/extractor.h"
 
+#include <algorithm>
+
 #include "annotator/datetime/utils.h"
 #include "annotator/model_generated.h"
 #include "annotator/types.h"
@@ -347,10 +349,11 @@
     }
   }
 
-  std::sort(found_numbers.begin(), found_numbers.end(),
-            [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
-              return a.first < b.first;
-            });
+  std::stable_sort(
+      found_numbers.begin(), found_numbers.end(),
+      [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
+        return a.first < b.first;
+      });
 
   int sum = 0;
   int running_value = -1;
diff --git a/native/annotator/datetime/regex-parser.cc b/native/annotator/datetime/regex-parser.cc
index 4dc9c56..5daabd5 100644
--- a/native/annotator/datetime/regex-parser.cc
+++ b/native/annotator/datetime/regex-parser.cc
@@ -16,6 +16,7 @@
 
 #include "annotator/datetime/regex-parser.h"
 
+#include <algorithm>
 #include <iterator>
 #include <set>
 #include <unordered_set>
@@ -191,17 +192,17 @@
 
   // Resolve conflicts by always picking the longer span and breaking ties by
   // selecting the earlier entry in the list for a given locale.
-  std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
-            [](const std::pair<DatetimeParseResultSpan, int>& a,
-               const std::pair<DatetimeParseResultSpan, int>& b) {
-              if ((a.first.span.second - a.first.span.first) !=
-                  (b.first.span.second - b.first.span.first)) {
-                return (a.first.span.second - a.first.span.first) >
-                       (b.first.span.second - b.first.span.first);
-              } else {
-                return a.second < b.second;
-              }
-            });
+  std::stable_sort(indexed_found_spans.begin(), indexed_found_spans.end(),
+                   [](const std::pair<DatetimeParseResultSpan, int>& a,
+                      const std::pair<DatetimeParseResultSpan, int>& b) {
+                     if ((a.first.span.second - a.first.span.first) !=
+                         (b.first.span.second - b.first.span.first)) {
+                       return (a.first.span.second - a.first.span.first) >
+                              (b.first.span.second - b.first.span.first);
+                     } else {
+                       return a.second < b.second;
+                     }
+                   });
 
   std::vector<DatetimeParseResultSpan> results;
   std::vector<DatetimeParseResultSpan> resolved_found_spans;
@@ -394,10 +395,10 @@
     }
 
     // Sort the date time units by component type.
-    std::sort(date_components.begin(), date_components.end(),
-              [](DatetimeComponent a, DatetimeComponent b) {
-                return a.component_type > b.component_type;
-              });
+    std::stable_sort(date_components.begin(), date_components.end(),
+                     [](DatetimeComponent a, DatetimeComponent b) {
+                       return a.component_type > b.component_type;
+                     });
     result.datetime_components.swap(date_components);
     results->push_back(result);
   }
diff --git a/native/annotator/translate/translate.cc b/native/annotator/translate/translate.cc
index 640ceec..2c5a43c 100644
--- a/native/annotator/translate/translate.cc
+++ b/native/annotator/translate/translate.cc
@@ -16,6 +16,7 @@
 
 #include "annotator/translate/translate.h"
 
+#include <algorithm>
 #include <memory>
 
 #include "annotator/collections.h"
@@ -142,11 +143,11 @@
     result.push_back({key, value});
   }
 
-  std::sort(result.begin(), result.end(),
-            [](TranslateAnnotator::LanguageConfidence& a,
-               TranslateAnnotator::LanguageConfidence& b) {
-              return a.confidence > b.confidence;
-            });
+  std::stable_sort(result.begin(), result.end(),
+                   [](const TranslateAnnotator::LanguageConfidence& a,
+                      const TranslateAnnotator::LanguageConfidence& b) {
+                     return a.confidence > b.confidence;
+                   });
   return result;
 }
 
diff --git a/native/lang_id/common/embedding-network.cc b/native/lang_id/common/embedding-network.cc
index 469cb1f..49c9ca0 100644
--- a/native/lang_id/common/embedding-network.cc
+++ b/native/lang_id/common/embedding-network.cc
@@ -16,6 +16,8 @@
 
 #include "lang_id/common/embedding-network.h"
 
+#include <vector>
+
 #include "lang_id/common/lite_base/integral-types.h"
 #include "lang_id/common/lite_base/logging.h"
 
diff --git a/native/lang_id/common/fel/feature-extractor.cc b/native/lang_id/common/fel/feature-extractor.cc
index ab8a1a6..4e304fe 100644
--- a/native/lang_id/common/fel/feature-extractor.cc
+++ b/native/lang_id/common/fel/feature-extractor.cc
@@ -17,6 +17,7 @@
 #include "lang_id/common/fel/feature-extractor.h"
 
 #include <string>
+#include <vector>
 
 #include "lang_id/common/fel/feature-types.h"
 #include "lang_id/common/fel/fel-parser.h"
diff --git a/native/lang_id/common/fel/workspace.cc b/native/lang_id/common/fel/workspace.cc
index af41e29..60dcc46 100644
--- a/native/lang_id/common/fel/workspace.cc
+++ b/native/lang_id/common/fel/workspace.cc
@@ -18,6 +18,7 @@
 
 #include <atomic>
 #include <string>
+#include <vector>
 
 namespace libtextclassifier3 {
 namespace mobile {
diff --git a/native/lang_id/common/fel/workspace.h b/native/lang_id/common/fel/workspace.h
index f13d802..2ac5b26 100644
--- a/native/lang_id/common/fel/workspace.h
+++ b/native/lang_id/common/fel/workspace.h
@@ -23,6 +23,7 @@
 
 #include <stddef.h>
 
+#include <algorithm>
 #include <string>
 #include <unordered_map>
 #include <utility>
diff --git a/native/lang_id/common/file/mmap.cc b/native/lang_id/common/file/mmap.cc
index 19afcc4..fc925ea 100644
--- a/native/lang_id/common/file/mmap.cc
+++ b/native/lang_id/common/file/mmap.cc
@@ -29,6 +29,8 @@
 #endif
 #include <sys/stat.h>
 
+#include <string>
+
 #include "lang_id/common/lite_base/logging.h"
 #include "lang_id/common/lite_base/macros.h"
 
diff --git a/native/lang_id/common/lite_strings/str-split.cc b/native/lang_id/common/lite_strings/str-split.cc
index 199bb69..d227eec 100644
--- a/native/lang_id/common/lite_strings/str-split.cc
+++ b/native/lang_id/common/lite_strings/str-split.cc
@@ -16,6 +16,8 @@
 
 #include "lang_id/common/lite_strings/str-split.h"
 
+#include <vector>
+
 namespace libtextclassifier3 {
 namespace mobile {
 
diff --git a/native/lang_id/common/math/softmax.cc b/native/lang_id/common/math/softmax.cc
index 750341d..249ed57 100644
--- a/native/lang_id/common/math/softmax.cc
+++ b/native/lang_id/common/math/softmax.cc
@@ -17,6 +17,7 @@
 #include "lang_id/common/math/softmax.h"
 
 #include <algorithm>
+#include <vector>
 
 #include "lang_id/common/lite_base/logging.h"
 #include "lang_id/common/math/fastexp.h"
diff --git a/native/lang_id/fb_model/lang-id-from-fb.cc b/native/lang_id/fb_model/lang-id-from-fb.cc
index dc36fb7..51c8c47 100644
--- a/native/lang_id/fb_model/lang-id-from-fb.cc
+++ b/native/lang_id/fb_model/lang-id-from-fb.cc
@@ -16,7 +16,9 @@
 
 #include "lang_id/fb_model/lang-id-from-fb.h"
 
+#include <memory>
 #include <string>
+#include <utility>
 
 #include "lang_id/fb_model/model-provider-from-fb.h"
 
diff --git a/native/lang_id/fb_model/model-provider-from-fb.cc b/native/lang_id/fb_model/model-provider-from-fb.cc
index 43bf860..d14d403 100644
--- a/native/lang_id/fb_model/model-provider-from-fb.cc
+++ b/native/lang_id/fb_model/model-provider-from-fb.cc
@@ -16,7 +16,9 @@
 
 #include "lang_id/fb_model/model-provider-from-fb.h"
 
+#include <memory>
 #include <string>
+#include <utility>
 
 #include "lang_id/common/file/file-utils.h"
 #include "lang_id/common/file/mmap.h"
diff --git a/native/lang_id/lang-id.cc b/native/lang_id/lang-id.cc
index 92359a9..f7c66f7 100644
--- a/native/lang_id/lang-id.cc
+++ b/native/lang_id/lang-id.cc
@@ -21,6 +21,7 @@
 #include <memory>
 #include <string>
 #include <unordered_map>
+#include <utility>
 #include <vector>
 
 #include "lang_id/common/embedding-feature-interface.h"
diff --git a/native/utils/codepoint-range.cc b/native/utils/codepoint-range.cc
index e26b160..a4cd485 100644
--- a/native/utils/codepoint-range.cc
+++ b/native/utils/codepoint-range.cc
@@ -31,10 +31,11 @@
         CodepointRangeStruct(range->start(), range->end()));
   }
 
-  std::sort(sorted_codepoint_ranges->begin(), sorted_codepoint_ranges->end(),
-            [](const CodepointRangeStruct& a, const CodepointRangeStruct& b) {
-              return a.start < b.start;
-            });
+  std::stable_sort(
+      sorted_codepoint_ranges->begin(), sorted_codepoint_ranges->end(),
+      [](const CodepointRangeStruct& a, const CodepointRangeStruct& b) {
+        return a.start < b.start;
+      });
 }
 
 // Returns true if given codepoint is covered by the given sorted vector of
diff --git a/native/utils/grammar/parsing/parser.cc b/native/utils/grammar/parsing/parser.cc
index 4e39a98..a9e99ba 100644
--- a/native/utils/grammar/parsing/parser.cc
+++ b/native/utils/grammar/parsing/parser.cc
@@ -16,6 +16,7 @@
 
 #include "utils/grammar/parsing/parser.h"
 
+#include <algorithm>
 #include <unordered_map>
 
 #include "utils/grammar/parsing/parse-tree.h"
@@ -177,14 +178,14 @@
     }
   }
 
-  std::sort(symbols.begin(), symbols.end(),
-            [](const Symbol& a, const Symbol& b) {
-              // Sort by increasing (end, start) position to guarantee the
-              // matcher requirement that the tokens are fed in non-decreasing
-              // end position order.
-              return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
-                     std::tie(b.codepoint_span.second, b.codepoint_span.first);
-            });
+  std::stable_sort(
+      symbols.begin(), symbols.end(), [](const Symbol& a, const Symbol& b) {
+        // Sort by increasing (end, start) position to guarantee the
+        // matcher requirement that the tokens are fed in non-decreasing
+        // end position order.
+        return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
+               std::tie(b.codepoint_span.second, b.codepoint_span.first);
+      });
 
   return symbols;
 }
diff --git a/native/utils/grammar/utils/ir.cc b/native/utils/grammar/utils/ir.cc
index dd29e3c..c134550 100644
--- a/native/utils/grammar/utils/ir.cc
+++ b/native/utils/grammar/utils/ir.cc
@@ -16,6 +16,8 @@
 
 #include "utils/grammar/utils/ir.h"
 
+#include <algorithm>
+
 #include "utils/i18n/locale.h"
 #include "utils/strings/append.h"
 #include "utils/strings/stringpiece.h"
@@ -28,14 +30,16 @@
 
 template <typename T>
 void SortForBinarySearchLookup(T* entries) {
-  std::sort(entries->begin(), entries->end(),
-            [](const auto& a, const auto& b) { return a->key < b->key; });
+  std::stable_sort(
+      entries->begin(), entries->end(),
+      [](const auto& a, const auto& b) { return a->key < b->key; });
 }
 
 template <typename T>
 void SortStructsForBinarySearchLookup(T* entries) {
-  std::sort(entries->begin(), entries->end(),
-            [](const auto& a, const auto& b) { return a.key() < b.key(); });
+  std::stable_sort(
+      entries->begin(), entries->end(),
+      [](const auto& a, const auto& b) { return a.key() < b.key(); });
 }
 
 bool IsSameLhs(const Ir::Lhs& lhs, const RulesSet_::Lhs& other) {
@@ -76,13 +80,14 @@
 
 Ir::LhsSet SortedLhsSet(const Ir::LhsSet& lhs_set) {
   Ir::LhsSet sorted_lhs = lhs_set;
-  std::sort(sorted_lhs.begin(), sorted_lhs.end(),
-            [](const Ir::Lhs& a, const Ir::Lhs& b) {
-              return std::tie(a.nonterminal, a.callback.id, a.callback.param,
-                              a.preconditions.max_whitespace_gap) <
-                     std::tie(b.nonterminal, b.callback.id, b.callback.param,
-                              b.preconditions.max_whitespace_gap);
-            });
+  std::stable_sort(
+      sorted_lhs.begin(), sorted_lhs.end(),
+      [](const Ir::Lhs& a, const Ir::Lhs& b) {
+        return std::tie(a.nonterminal, a.callback.id, a.callback.param,
+                        a.preconditions.max_whitespace_gap) <
+               std::tie(b.nonterminal, b.callback.id, b.callback.param,
+                        b.preconditions.max_whitespace_gap);
+      });
   return lhs_set;
 }
 
@@ -300,10 +305,10 @@
           TerminalEntry{it.first, /*set_index=*/i, /*index=*/0, it.second});
     }
   }
-  std::sort(terminal_rules.begin(), terminal_rules.end(),
-            [](const TerminalEntry& a, const TerminalEntry& b) {
-              return a.terminal < b.terminal;
-            });
+  std::stable_sort(terminal_rules.begin(), terminal_rules.end(),
+                   [](const TerminalEntry& a, const TerminalEntry& b) {
+                     return a.terminal < b.terminal;
+                   });
 
   // Index the entries in sorted order.
   std::vector<int> index(terminal_rules_sets.size(), 0);
diff --git a/native/utils/grammar/utils/locale-shard-map.cc b/native/utils/grammar/utils/locale-shard-map.cc
index e6db06d..141ce5d 100644
--- a/native/utils/grammar/utils/locale-shard-map.cc
+++ b/native/utils/grammar/utils/locale-shard-map.cc
@@ -40,8 +40,8 @@
       locale_list.emplace_back(locale);
     }
   }
-  std::sort(locale_list.begin(), locale_list.end(),
-            [](const Locale& a, const Locale& b) { return a < b; });
+  std::stable_sort(locale_list.begin(), locale_list.end(),
+                   [](const Locale& a, const Locale& b) { return a < b; });
   return locale_list;
 }
 
diff --git a/native/utils/testing/test_data_generator.h b/native/utils/testing/test_data_generator.h
index 30c7aed..c23b5dc 100644
--- a/native/utils/testing/test_data_generator.h
+++ b/native/utils/testing/test_data_generator.h
@@ -20,6 +20,7 @@
 #include <algorithm>
 #include <iostream>
 #include <random>
+#include <string>
 
 #include "utils/strings/stringpiece.h"
 
@@ -35,6 +36,18 @@
     return dist(random_engine_);
   }
 
+  template <>
+  bool generate() {
+    std::bernoulli_distribution dist(0.5);
+    return dist(random_engine_);
+  }
+
+  template <>
+  char generate() {
+    std::uniform_int_distribution<int> dist(0, 25);
+    return dist(random_engine_) + 'a';
+  }
+
   template <typename T, typename std::enable_if_t<
                             std::is_floating_point<T>::value>* = nullptr>
   T generate() {
diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc
index 463d910..644dde8 100644
--- a/native/utils/tflite-model-executor.cc
+++ b/native/utils/tflite-model-executor.cc
@@ -27,6 +27,8 @@
 TfLiteRegistration* Register_ADD();
 TfLiteRegistration* Register_CONCATENATION();
 TfLiteRegistration* Register_CONV_2D();
+TfLiteRegistration* Register_DEPTHWISE_CONV_2D();
+TfLiteRegistration* Register_AVERAGE_POOL_2D();
 TfLiteRegistration* Register_EQUAL();
 TfLiteRegistration* Register_FULLY_CONNECTED();
 TfLiteRegistration* Register_GREATER_EQUAL();
@@ -89,7 +91,9 @@
 #include "utils/tflite/dist_diversification.h"
 #include "utils/tflite/string_projection.h"
 #include "utils/tflite/text_encoder.h"
+#include "utils/tflite/text_encoder3s.h"
 #include "utils/tflite/token_encoder.h"
+
 namespace tflite {
 namespace ops {
 namespace custom {
@@ -114,6 +118,14 @@
                        tflite::ops::builtin::Register_CONV_2D(),
                        /*min_version=*/1,
                        /*max_version=*/5);
+  resolver->AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
+                       tflite::ops::builtin::Register_DEPTHWISE_CONV_2D(),
+                       /*min_version=*/1,
+                       /*max_version=*/6);
+  resolver->AddBuiltin(tflite::BuiltinOperator_AVERAGE_POOL_2D,
+                       tflite::ops::builtin::Register_AVERAGE_POOL_2D(),
+                       /*min_version=*/1,
+                       /*max_version=*/1);
   resolver->AddBuiltin(::tflite::BuiltinOperator_EQUAL,
                        ::tflite::ops::builtin::Register_EQUAL());
 
@@ -289,6 +301,8 @@
                       tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION());
   resolver->AddCustom("TextEncoder",
                       tflite::ops::custom::Register_TEXT_ENCODER());
+  resolver->AddCustom("TextEncoder3S",
+                      tflite::ops::custom::Register_TEXT_ENCODER3S());
   resolver->AddCustom("TokenEncoder",
                       tflite::ops::custom::Register_TOKEN_ENCODER());
   resolver->AddCustom(
diff --git a/native/utils/tflite/encoder_common.cc b/native/utils/tflite/encoder_common.cc
index 8f9f2a8..eb319f9 100644
--- a/native/utils/tflite/encoder_common.cc
+++ b/native/utils/tflite/encoder_common.cc
@@ -58,6 +58,11 @@
                   out->data.i32 + output_offset + from_this_element,
                   in.data.i32[value_index]);
       } break;
+      case kTfLiteInt64: {
+        std::fill(out->data.i64 + output_offset,
+                  out->data.i64 + output_offset + from_this_element,
+                  in.data.i64[value_index]);
+      } break;
       case kTfLiteFloat32: {
         std::fill(out->data.f + output_offset,
                   out->data.f + output_offset + from_this_element,
@@ -78,6 +83,12 @@
       std::fill(out->data.i32 + output_offset, out->data.i32 + output_size,
                 value);
     } break;
+    case kTfLiteInt64: {
+      const int64_t value =
+          (output_offset > 0) ? out->data.i64[output_offset - 1] : 0;
+      std::fill(out->data.i64 + output_offset, out->data.i64 + output_size,
+                value);
+    } break;
     case kTfLiteFloat32: {
       const float value =
           (output_offset > 0) ? out->data.f[output_offset - 1] : 0;
diff --git a/native/utils/tflite/text_encoder3s.cc b/native/utils/tflite/text_encoder3s.cc
new file mode 100644
index 0000000..0b5e65b
--- /dev/null
+++ b/native/utils/tflite/text_encoder3s.cc
@@ -0,0 +1,243 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/text_encoder3s.h"
+
+#include <memory>
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/tflite/encoder_common.h"
+#include "utils/tflite/text_encoder_config_generated.h"
+#include "utils/tokenfree/byte_encoder.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Input parameters for the op.
+constexpr int kInputTextInd = 0;
+
+constexpr int kTextLengthInd = 1;
+constexpr int kMaxLengthInd = 2;
+constexpr int kInputAttrInd = 3;
+
+// Output parameters for the op.
+constexpr int kOutputEncodedInd = 0;
+constexpr int kOutputPositionInd = 1;
+constexpr int kOutputLengthsInd = 2;
+constexpr int kOutputAttrInd = 3;
+
+// Initializes text encoder object from serialized parameters.
+void* Initialize(TfLiteContext* context, const char* buffer, size_t length) {
+  std::unique_ptr<ByteEncoder> encoder(new ByteEncoder());
+  return encoder.release();
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+  delete reinterpret_cast<ByteEncoder*>(buffer);
+}
+
+namespace {
+TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
+                                 int max_output_length) {
+  TfLiteTensor& output_encoded =
+      context->tensors[node->outputs->data[kOutputEncodedInd]];
+
+  TF_LITE_ENSURE_OK(
+      context, context->ResizeTensor(
+                   context, &output_encoded,
+                   CreateIntArray({kEncoderBatchSize, max_output_length})));
+  TfLiteTensor& output_positions =
+      context->tensors[node->outputs->data[kOutputPositionInd]];
+
+  TF_LITE_ENSURE_OK(
+      context, context->ResizeTensor(
+                   context, &output_positions,
+                   CreateIntArray({kEncoderBatchSize, max_output_length})));
+
+  const int num_output_attrs = node->outputs->size - kOutputAttrInd;
+  for (int i = 0; i < num_output_attrs; ++i) {
+    TfLiteTensor& output =
+        context->tensors[node->outputs->data[kOutputAttrInd + i]];
+    TF_LITE_ENSURE_OK(
+        context, context->ResizeTensor(
+                     context, &output,
+                     CreateIntArray({kEncoderBatchSize, max_output_length})));
+  }
+  return kTfLiteOk;
+}
+}  // namespace
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  // Check that the batch dimension is kEncoderBatchSize.
+  const TfLiteTensor& input_text =
+      context->tensors[node->inputs->data[kInputTextInd]];
+  TF_LITE_ENSURE_EQ(context, input_text.dims->size, kEncoderInputRank);
+  TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kEncoderBatchSize);
+
+  TfLiteTensor& output_lengths =
+      context->tensors[node->outputs->data[kOutputLengthsInd]];
+
+  TfLiteTensor& output_encoded =
+      context->tensors[node->outputs->data[kOutputEncodedInd]];
+  TfLiteTensor& output_positions =
+      context->tensors[node->outputs->data[kOutputPositionInd]];
+  output_encoded.type = kTfLiteInt32;
+  output_positions.type = kTfLiteInt32;
+  output_lengths.type = kTfLiteInt32;
+
+  TF_LITE_ENSURE_OK(context,
+                    context->ResizeTensor(context, &output_lengths,
+                                          CreateIntArray({kEncoderBatchSize})));
+
+  // Check that there are enough outputs for attributes.
+  const int num_output_attrs = node->outputs->size - kOutputAttrInd;
+  TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd,
+                    num_output_attrs);
+
+  // Copy attribute types from input to output tensors.
+  for (int i = 0; i < num_output_attrs; ++i) {
+    TfLiteTensor& input =
+        context->tensors[node->inputs->data[kInputAttrInd + i]];
+    TfLiteTensor& output =
+        context->tensors[node->outputs->data[kOutputAttrInd + i]];
+    output.type = input.type;
+  }
+
+  const TfLiteTensor& output_length =
+      context->tensors[node->inputs->data[kMaxLengthInd]];
+
+  if (tflite::IsConstantTensor(&output_length)) {
+    return ResizeOutputTensors(context, node, output_length.data.i64[0]);
+  } else {
+    tflite::SetTensorToDynamic(&output_encoded);
+    tflite::SetTensorToDynamic(&output_positions);
+    for (int i = 0; i < num_output_attrs; ++i) {
+      TfLiteTensor& output_attr =
+          context->tensors[node->outputs->data[kOutputAttrInd + i]];
+      tflite::SetTensorToDynamic(&output_attr);
+    }
+  }
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  if (node->user_data == nullptr) {
+    return kTfLiteError;
+  }
+  auto text_encoder = reinterpret_cast<ByteEncoder*>(node->user_data);
+  const TfLiteTensor& input_text =
+      context->tensors[node->inputs->data[kInputTextInd]];
+  const int num_strings_in_tensor = tflite::GetStringCount(&input_text);
+  const int num_strings =
+      context->tensors[node->inputs->data[kTextLengthInd]].data.i32[0];
+
+  // Check that the number of strings is not bigger than the input tensor size.
+  TF_LITE_ENSURE(context, num_strings_in_tensor >= num_strings);
+
+  TfLiteTensor& output_encoded =
+      context->tensors[node->outputs->data[kOutputEncodedInd]];
+  if (tflite::IsDynamicTensor(&output_encoded)) {
+    const TfLiteTensor& output_length =
+        context->tensors[node->inputs->data[kMaxLengthInd]];
+    TF_LITE_ENSURE_OK(
+        context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
+  }
+  TfLiteTensor& output_positions =
+      context->tensors[node->outputs->data[kOutputPositionInd]];
+
+  std::vector<int> encoded_total;
+  std::vector<int> encoded_positions;
+  std::vector<int> encoded_offsets;
+  encoded_offsets.reserve(num_strings);
+  const int max_output_length = output_encoded.dims->data[1];
+  const int max_encoded_position = max_output_length;
+
+  for (int i = 0; i < num_strings; ++i) {
+    const auto& strref = tflite::GetString(&input_text, i);
+    std::vector<int64_t> encoded;
+    text_encoder->Encode(
+        libtextclassifier3::StringPiece(strref.str, strref.len), &encoded);
+    encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
+    encoded_offsets.push_back(encoded_total.size());
+    for (int i = 0; i < encoded.size(); ++i) {
+      encoded_positions.push_back(std::min(i, max_encoded_position - 1));
+    }
+  }
+
+  // Copy encoding to output tensor.
+  const int start_offset =
+      std::max(0, static_cast<int>(encoded_total.size()) - max_output_length);
+  int output_offset = 0;
+  int32_t* output_buffer = output_encoded.data.i32;
+  int32_t* output_positions_buffer = output_positions.data.i32;
+  for (int i = start_offset; i < encoded_total.size(); ++i, ++output_offset) {
+    output_buffer[output_offset] = encoded_total[i];
+    output_positions_buffer[output_offset] = encoded_positions[i];
+  }
+
+  // Save output encoded length.
+  TfLiteTensor& output_lengths =
+      context->tensors[node->outputs->data[kOutputLengthsInd]];
+  output_lengths.data.i32[0] = output_offset;
+
+  // Do padding.
+  for (; output_offset < max_output_length; ++output_offset) {
+    output_buffer[output_offset] = 0;
+    output_positions_buffer[output_offset] = 0;
+  }
+
+  // Process attributes, all checks of sizes and types are done in Prepare.
+  const int num_output_attrs = node->outputs->size - kOutputAttrInd;
+  TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd,
+                    num_output_attrs);
+  for (int i = 0; i < num_output_attrs; ++i) {
+    TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate(
+        context->tensors[node->inputs->data[kInputAttrInd + i]],
+        encoded_offsets, start_offset, context,
+        &context->tensors[node->outputs->data[kOutputAttrInd + i]]);
+    if (attr_status != kTfLiteOk) {
+      return attr_status;
+    }
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace
+}  // namespace libtextclassifier3
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_TEXT_ENCODER3S() {
+  static TfLiteRegistration registration = {
+      libtextclassifier3::Initialize, libtextclassifier3::Free,
+      libtextclassifier3::Prepare, libtextclassifier3::Eval};
+  return &registration;
+}
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
diff --git a/native/utils/tflite/text_encoder3s.h b/native/utils/tflite/text_encoder3s.h
new file mode 100644
index 0000000..50e1e64
--- /dev/null
+++ b/native/utils/tflite/text_encoder3s.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// An encoder that produces positional and attribute encodings for a
+// transformer-style model based on byte segmentation of text.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER3S_H_
+#define LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER3S_H_
+
+#include "tensorflow/lite/context.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_TEXT_ENCODER3S();
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
+
+#endif  // LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER3S_H_
diff --git a/native/utils/tokenfree/byte_encoder.cc b/native/utils/tokenfree/byte_encoder.cc
new file mode 100644
index 0000000..c79d3a2
--- /dev/null
+++ b/native/utils/tokenfree/byte_encoder.cc
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tokenfree/byte_encoder.h"
+
+#include <vector>
+namespace libtextclassifier3 {
+
+// Encodes each byte of `input_text` as its unsigned value (0..255),
+// one int64_t per byte, into `encoded_text`. An empty input produces an
+// empty vector. Always succeeds and returns true.
+bool ByteEncoder::Encode(StringPiece input_text,
+                         std::vector<int64_t>* encoded_text) const {
+  const int size = input_text.size();
+  if (size <= 0) {
+    *encoded_text = {};
+    return true;
+  }
+
+  encoded_text->resize(size);
+
+  const auto& text = input_text.ToString();
+  for (int i = 0; i < size; i++) {
+    // Cast through unsigned char first: plain char is signed on some
+    // platforms (e.g. x86) and unsigned on others (e.g. ARM), so casting
+    // char directly to int64_t would encode bytes >= 0x80 as negative
+    // values on signed-char platforms. This makes the output a
+    // deterministic 0..255 everywhere.
+    (*encoded_text)[i] =
+        static_cast<int64_t>(static_cast<unsigned char>(text[i]));
+  }
+
+  return true;
+}
+
+}  // namespace libtextclassifier3
diff --git a/native/utils/tokenfree/byte_encoder.h b/native/utils/tokenfree/byte_encoder.h
new file mode 100644
index 0000000..1a495ec
--- /dev/null
+++ b/native/utils/tokenfree/byte_encoder.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TOKENFREE_BYTE_ENCODER_H_
+#define LIBTEXTCLASSIFIER_UTILS_TOKENFREE_BYTE_ENCODER_H_
+
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/container/string-set.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Encoder that segments/tokenizes strings into individual bytes.
+class ByteEncoder {
+ public:
+  // Writes one int64_t per byte of `input_text` into `encoded_text`,
+  // replacing any previous contents. Returns true on success (the
+  // implementation in byte_encoder.cc always succeeds).
+  bool Encode(StringPiece input_text, std::vector<int64_t>* encoded_text) const;
+  ByteEncoder() {}
+};
+
+}  // namespace libtextclassifier3
+
+#endif  // LIBTEXTCLASSIFIER_UTILS_TOKENFREE_BYTE_ENCODER_H_
diff --git a/native/utils/tokenfree/byte_encoder_test.cc b/native/utils/tokenfree/byte_encoder_test.cc
new file mode 100644
index 0000000..d4d119e
--- /dev/null
+++ b/native/utils/tokenfree/byte_encoder_test.cc
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tokenfree/byte_encoder.h"
+
+#include <memory>
+#include <vector>
+
+#include "utils/base/integral_types.h"
+#include "utils/container/sorted-strings-table.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAre;
+
+// Verifies that a lowercase ASCII string is encoded as one value per
+// byte, each equal to the character's ASCII code.
+TEST(EncoderTest, SimpleTokenization) {
+  const ByteEncoder encoder;
+  {
+    std::vector<int64_t> encoded_text;
+    EXPECT_TRUE(encoder.Encode("hellothere", &encoded_text));
+    EXPECT_THAT(encoded_text,
+                ElementsAre(104, 101, 108, 108, 111, 116, 104, 101, 114, 101));
+  }
+}
+
+// Verifies byte-wise encoding of a string with a capitalized first
+// letter ('H' is 72, distinct from lowercase 'h').
+TEST(EncoderTest, SimpleTokenization2) {
+  const ByteEncoder encoder;
+  std::vector<int64_t> encoded_text;
+  EXPECT_TRUE(encoder.Encode("Hello", &encoded_text));
+  EXPECT_THAT(encoded_text, ElementsAre(72, 101, 108, 108, 111));
+}
+}  // namespace
+}  // namespace libtextclassifier3
diff --git a/native/utils/tokenizer.cc b/native/utils/tokenizer.cc
index 071141c..7038517 100644
--- a/native/utils/tokenizer.cc
+++ b/native/utils/tokenizer.cc
@@ -43,11 +43,12 @@
     codepoint_ranges_.emplace_back(range->UnPack());
   }
 
-  std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
-            [](const std::unique_ptr<const TokenizationCodepointRangeT>& a,
-               const std::unique_ptr<const TokenizationCodepointRangeT>& b) {
-              return a->start < b->start;
-            });
+  std::stable_sort(
+      codepoint_ranges_.begin(), codepoint_ranges_.end(),
+      [](const std::unique_ptr<const TokenizationCodepointRangeT>& a,
+         const std::unique_ptr<const TokenizationCodepointRangeT>& b) {
+        return a->start < b->start;
+      });
 
   SortCodepointRanges(internal_tokenizer_codepoint_ranges,
                       &internal_tokenizer_codepoint_ranges_);
diff --git a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java
index bc30fcf..f539ba7 100644
--- a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java
+++ b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java
@@ -37,15 +37,20 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
 
 @LargeTest
 @RunWith(AndroidJUnit4.class)
 public class SmartSuggestionsLogSessionTest {
+
+  @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
   private static final String RESULT_ID = "resultId";
   private static final String REPLY = "reply";
   private static final float SCORE = 0.5f;
@@ -55,7 +60,6 @@
 
   @Before
   public void setup() {
-    MockitoAnnotations.initMocks(this);
 
     session =
         new SmartSuggestionsLogSession(