Snap for 9254005 from 1075b1e4e39ab4af90deb3758e5631943c07d47e to mainline-uwb-release

Change-Id: Ibbba9bec7558f1f9e724304545c11e5813dcfd56
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index 9f9a8d4..eeeb508 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -21,6 +21,8 @@
 #include <vector>
 
 #include "utils/base/statusor.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/random/random.h"
 
 #if !defined(TC3_DISABLE_LUA)
 #include "actions/lua-actions.h"
@@ -42,6 +44,7 @@
 #include "utils/strings/utf8.h"
 #include "utils/utf8/unicodetext.h"
 #include "absl/container/flat_hash_set.h"
+#include "absl/random/distributions.h"
 #include "tensorflow/lite/string_util.h"
 
 namespace libtextclassifier3 {
@@ -813,6 +816,8 @@
     const tflite::Interpreter* interpreter, int suggestion_index,
     int score_index, const std::string& type, float priority_score,
     const absl::flat_hash_set<std::string>& blocklist,
+    const absl::flat_hash_map<std::string, std::vector<std::string>>&
+        concept_mappings,
     ActionsSuggestionsResponse* response) const {
   const std::vector<tflite::StringRef> replies =
       model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
@@ -831,6 +836,12 @@
     if (blocklist.contains(response_text)) {
       continue;
     }
+    const auto concept_it = concept_mappings.find(response_text);
+    if (concept_it != concept_mappings.end() && !concept_it->second.empty()) {
+      // Pick from [0, size) so every candidate is reachable and in bounds.
+      response_text = concept_it->second[absl::Uniform<size_t>(
+          absl::IntervalClosedOpen, bit_gen_, 0, concept_it->second.size())];
+    }
 
     response->actions.push_back({response_text, type, score, priority_score});
   }
@@ -918,11 +929,11 @@
   if (!response->output_filtered_min_triggering_score &&
       model_->tflite_model_spec()->output_replies() >= 0) {
     absl::flat_hash_set<std::string> empty_blocklist;
-    PopulateTextReplies(interpreter,
-                        model_->tflite_model_spec()->output_replies(),
-                        model_->tflite_model_spec()->output_replies_scores(),
-                        model_->smart_reply_action_type()->str(),
-                        /* priority_score */ 0.0, empty_blocklist, response);
+    PopulateTextReplies(
+        interpreter, model_->tflite_model_spec()->output_replies(),
+        model_->tflite_model_spec()->output_replies_scores(),
+        model_->smart_reply_action_type()->str(),
+        /* priority_score */ 0.0, empty_blocklist, {}, response);
   }
 
   // Read actions suggestions.
@@ -961,6 +972,8 @@
       const int suggestions_scores_index =
           metadata->output_suggestions_scores();
       absl::flat_hash_set<std::string> response_text_blocklist;
+      absl::flat_hash_map<std::string, std::vector<std::string>>
+          concept_mappings;
       switch (metadata->prediction_type()) {
         case PredictionType_NEXT_MESSAGE_PREDICTION:
           if (!task_spec || task_spec->type()->size() == 0) {
@@ -973,13 +986,28 @@
                 response_text_blocklist.insert(val->str());
               }
             }
+            // Concept name -> candidate replies; later duplicates override.
+            if (task_spec->concept_mappings()) {
+              for (const auto* entry : *task_spec->concept_mappings()) {
+                // Optional flatbuffer fields may be absent; skip such entries.
+                if (!entry->concept_name() || !entry->candidates()) {
+                  continue;
+                }
+                std::vector<std::string> candidates;
+                for (const auto* candidate : *entry->candidates()) {
+                  candidates.push_back(candidate->str());
+                }
+                concept_mappings[entry->concept_name()->str()] =
+                    std::move(candidates);
+              }
+            }
           }
           PopulateTextReplies(
               interpreter, suggestions_index, suggestions_scores_index,
               task_spec ? task_spec->type()->str()
                         : model_->smart_reply_action_type()->str(),
               task_spec ? task_spec->priority_score() : 0.0,
-              response_text_blocklist, response);
+              response_text_blocklist, concept_mappings, response);
           break;
         case PredictionType_INTENT_TRIGGERING:
           PopulateIntentTriggering(interpreter, suggestions_index,
diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h
index 87f55fb..c3d58e4 100644
--- a/native/actions/actions-suggestions.h
+++ b/native/actions/actions-suggestions.h
@@ -43,7 +43,9 @@
 #include "utils/utf8/unilib.h"
 #include "utils/variant.h"
 #include "utils/zlib/zlib.h"
+#include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "absl/random/random.h"
 
 namespace libtextclassifier3 {
 
@@ -175,11 +177,13 @@
   void FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec* spec,
                                             ActionSuggestion* suggestion) const;
 
-  void PopulateTextReplies(const tflite::Interpreter* interpreter,
-                           int suggestion_index, int score_index,
-                           const std::string& type, float priority_score,
-                           const absl::flat_hash_set<std::string>& blocklist,
-                           ActionsSuggestionsResponse* response) const;
+  void PopulateTextReplies(
+      const tflite::Interpreter* interpreter, int suggestion_index,
+      int score_index, const std::string& type, float priority_score,
+      const absl::flat_hash_set<std::string>& blocklist,
+      const absl::flat_hash_map<std::string, std::vector<std::string>>&
+          concept_mappings,
+      ActionsSuggestionsResponse* response) const;
 
   void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
                                 int suggestion_index, int score_index,
@@ -273,6 +277,9 @@
   // Conversation intent detection model for additional actions.
   std::unique_ptr<const ConversationIntentDetection>
       conversation_intent_detection_;
+
+  // Picks concept candidates at random. NOTE: absl::BitGen is not thread-safe.
+  mutable absl::BitGen bit_gen_;
 };
 
 // Interprets the buffer as a Model flatbuffer and returns it for reading.
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
index b51ebc7..65f9796 100644
--- a/native/actions/actions-suggestions_test.cc
+++ b/native/actions/actions-suggestions_test.cc
@@ -61,6 +61,8 @@
     "actions_suggestions_test.multi_task_sr_p13n.model";
 constexpr char kMultiTaskSrEmojiModelFileName[] =
     "actions_suggestions_test.multi_task_sr_emoji.model";
+constexpr char kMultiTaskSrEmojiConceptModelFileName[] =
+    "actions_suggestions_test.multi_task_sr_emoji_concept.model";
 constexpr char kSensitiveTFliteModelFileName[] =
     "actions_suggestions_test.sensitive_tflite.model";
 constexpr char kLiveRelayTFLiteModelFileName[] =
@@ -1835,6 +1837,25 @@
   EXPECT_EQ(response.actions[2].type, "text_reply");
 }
 
+TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelUsesConcepts) {
+  std::unique_ptr<ActionsSuggestions> actions_suggestions =
+      LoadTestModel(kMultiTaskSrEmojiConceptModelFileName);
+
+  const ActionsSuggestionsResponse response =
+      actions_suggestions->SuggestActions(
+          {{{/*user_id=*/1, "i am tired",
+             /*reference_time_ms_utc=*/0,
+             /*reference_timezone=*/"Europe/Zurich",
+             /*annotations=*/{},
+             /*locales=*/"en"}}});
+  ASSERT_FALSE(response.actions.empty());
+  const std::vector<std::string> sigh_emojis = {"😔", "😞"};
+  EXPECT_TRUE(std::find(sigh_emojis.begin(), sigh_emojis.end(),
+                        response.actions[0].response_text) != sigh_emojis.end())
+      << response.actions[0].response_text;
+  EXPECT_EQ(response.actions[0].type, "emoji_reply");
+}
+
 TEST_F(ActionsSuggestionsTest, LiveRelayModel) {
   std::unique_ptr<ActionsSuggestions> actions_suggestions =
       LoadTestModel(kLiveRelayTFLiteModelFileName);
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
index 0d8c7ad..70f9104 100644
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
@@ -312,6 +312,15 @@
   min_reply_score_threshold:float = 0;
 }
 
+// Maps a model output that is a concept (e.g. from emoji concept suggestion
+// models) to a list of candidate suggestions. One of the candidates is chosen
+// at random as the final suggestion.
+namespace libtextclassifier3;
+table ActionConceptToSuggestion {
+  concept_name:string (shared);
+  candidates:[string];
+}
+
 namespace libtextclassifier3;
 table ActionSuggestionSpec {
   // Type of the action suggestion.
@@ -331,6 +340,10 @@
 
   entity_data:ActionsEntityData;
   response_text_blocklist:[string];
+
+  // If provided, a response that matches a concept_name is replaced by one of
+  // that concept's candidates, chosen at random.
+  concept_mappings:[ActionConceptToSuggestion];
 }
 
 // Options to specify triggering behaviour per action class.
diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model
index a44bfe6..6d7bdb0 100644
--- a/native/actions/test_data/actions_suggestions_grammar_test.model
+++ b/native/actions/test_data/actions_suggestions_grammar_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model
index d262953..88f62eb 100644
--- a/native/actions/test_data/actions_suggestions_test.model
+++ b/native/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
index 96fd5ef..40a2409 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
index b77d2b5..effb2cb 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.model
new file mode 100644
index 0000000..18333d6
--- /dev/null
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji_concept.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
index ad9b684..e41ab39 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
index 7fa095a..5314b43 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
index 33bb389..a633742 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
index 11b7524..6685d26 100644
--- a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
+++ b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
Binary files differ
diff --git a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
index 9429b29..28f947b 100644
--- a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
+++ b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
@@ -96,16 +96,11 @@
           oldSession.destroy();
         }
       };
-  private final TextClassificationContext textClassificationContext;
 
   public SmartSuggestionsHelper(Context context, SmartSuggestionsConfig config) {
     this.context = context;
     textClassificationManager = this.context.getSystemService(TextClassificationManager.class);
     this.config = config;
-    this.textClassificationContext =
-        new TextClassificationContext.Builder(
-                context.getPackageName(), TextClassifier.WIDGET_TYPE_NOTIFICATION)
-            .build();
   }
 
   /**
@@ -170,7 +165,10 @@
           } else {
             SmartSuggestionsLogSession session =
                 new SmartSuggestionsLogSession(
-                    resultId, repliesScore, textClassifier, textClassificationContext);
+                    resultId,
+                    repliesScore,
+                    textClassifier,
+                    getTextClassificationContext(statusBarNotification));
             session.onSuggestionsGenerated(conversationActions);
 
             // Store the session if we expect more logging from it, destroy it otherwise.
@@ -302,7 +300,11 @@
             .setTypeConfig(typeConfigBuilder.build())
             .build();
 
-    TextClassifier textClassifier = createTextClassificationSession();
+    TextClassifier textClassifier =
+        textClassificationManager.createTextClassificationSession(
+            getTextClassificationContext(statusBarNotification));
+    onTextClassificationSessionCreated();
+
     return new SuggestConversationActionsResult(
         Optional.of(textClassifier), textClassifier.suggestConversationActions(request));
   }
@@ -477,8 +479,13 @@
   }
 
   @VisibleForTesting
-  TextClassifier createTextClassificationSession() {
-    return textClassificationManager.createTextClassificationSession(textClassificationContext);
+  void onTextClassificationSessionCreated() {}
+
+  private static TextClassificationContext getTextClassificationContext(
+      StatusBarNotification statusBarNotification) {
+    return new TextClassificationContext.Builder(
+            statusBarNotification.getPackageName(), TextClassifier.WIDGET_TYPE_NOTIFICATION)
+        .build();
   }
 
   private static boolean arePersonsEqual(Person left, Person right) {
diff --git a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
index 84cf4fb..9354819 100644
--- a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
+++ b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
@@ -86,9 +86,8 @@
     }
 
     @Override
-    TextClassifier createTextClassificationSession() {
+    void onTextClassificationSessionCreated() {
       numOfSessionsCreated += 1;
-      return super.createTextClassificationSession();
     }
 
     int getNumOfSessionsCreated() {
@@ -260,9 +259,11 @@
     assertThat(firstEvent.getEntityTypes())
         .asList()
         .containsExactly(ConversationAction.TYPE_TEXT_REPLY, ConversationAction.TYPE_OPEN_URL);
+    assertThat(firstEvent.getEventContext().getPackageName()).isEqualTo(PACKAGE_NAME);
     TextClassifierEvent secondEvent = textClassifierEvents.get(1);
     assertThat(secondEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_SMART_ACTION);
     assertThat(secondEvent.getEntityTypes()[0]).isEqualTo(ConversationAction.TYPE_TEXT_REPLY);
+    assertThat(secondEvent.getEventContext().getPackageName()).isEqualTo(PACKAGE_NAME);
   }
 
   @Test