RESTRICT AUTOMERGE TextClassifier cross-user vulnerability in direct-reply

Sys UI runs on user 0. This can lead to the TextClassifier (TC)
running for the wrong user. Consequencies are user A can launch apps
in user B via the TC's predicted actions and selected text being
unintentionally shared from user A to an app running in user B.

This fix ensures that the correct user id is passed and verified for
every TC request going across process boundaries (i.e. via SystemTC).
- Sys UI sets the appropriate user id in the TextView
- TextClassificationManager (TCM) system service is constructed using
  a context generated from this user id
- SystemTC sets this user id before querying the TCMService
- TCMService validates the user id before forwarding the request to
  the TCService belonging to that user id.

Bug: 136483597
Bug: 123232892
Test: atest android.view.textclassifier
      atest android.widget.TextViewActivityTest
      (manual) See I2fdffd8eb4221782cb1f34d2ddbe41dd3d36595c

Change-Id: Ibe68bc9e257521de97cbb014176b2b8ba23547d1
(cherry picked from commit 34e380cdd64230db81a5754b7b6e2654509af180)
diff --git a/core/java/android/view/textclassifier/ActionsModelParamsSupplier.java b/core/java/android/view/textclassifier/ActionsModelParamsSupplier.java
index 6b90588..3164567 100644
--- a/core/java/android/view/textclassifier/ActionsModelParamsSupplier.java
+++ b/core/java/android/view/textclassifier/ActionsModelParamsSupplier.java
@@ -60,7 +60,9 @@
     private boolean mParsed = true;
 
     public ActionsModelParamsSupplier(Context context, @Nullable Runnable onChangedListener) {
-        mAppContext = Preconditions.checkNotNull(context).getApplicationContext();
+        final Context appContext = Preconditions.checkNotNull(context).getApplicationContext();
+        // Some contexts don't have an app context.
+        mAppContext = appContext != null ? appContext : context;
         mOnChangedListener = onChangedListener == null ? () -> {} : onChangedListener;
         mSettingsObserver = new SettingsObserver(mAppContext, () -> {
             synchronized (mLock) {
diff --git a/core/java/android/view/textclassifier/ConversationActions.java b/core/java/android/view/textclassifier/ConversationActions.java
index aeb99b8..7c527ba 100644
--- a/core/java/android/view/textclassifier/ConversationActions.java
+++ b/core/java/android/view/textclassifier/ConversationActions.java
@@ -21,10 +21,12 @@
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.annotation.StringDef;
+import android.annotation.UserIdInt;
 import android.app.Person;
 import android.os.Bundle;
 import android.os.Parcel;
 import android.os.Parcelable;
+import android.os.UserHandle;
 import android.text.SpannedString;
 
 import com.android.internal.annotations.VisibleForTesting;
@@ -316,6 +318,8 @@
         private final List<String> mHints;
         @Nullable
         private String mCallingPackageName;
+        @UserIdInt
+        private int mUserId = UserHandle.USER_NULL;
         @NonNull
         private Bundle mExtras;
 
@@ -340,6 +344,7 @@
             List<String> hints = new ArrayList<>();
             in.readStringList(hints);
             String callingPackageName = in.readString();
+            int userId = in.readInt();
             Bundle extras = in.readBundle();
             Request request = new Request(
                     conversation,
@@ -348,6 +353,7 @@
                     hints,
                     extras);
             request.setCallingPackageName(callingPackageName);
+            request.setUserId(userId);
             return request;
         }
 
@@ -358,6 +364,7 @@
             parcel.writeInt(mMaxSuggestions);
             parcel.writeStringList(mHints);
             parcel.writeString(mCallingPackageName);
+            parcel.writeInt(mUserId);
             parcel.writeBundle(mExtras);
         }
 
@@ -428,6 +435,24 @@
         }
 
         /**
+         * Sets the id of the user that sent this request.
+         * <p>
+         * Package-private for SystemTextClassifier's use.
+         */
+        void setUserId(@UserIdInt int userId) {
+            mUserId = userId;
+        }
+
+        /**
+         * Returns the id of the user that sent this request.
+         * @hide
+         */
+        @UserIdInt
+        public int getUserId() {
+            return mUserId;
+        }
+
+        /**
          * Returns the extended data related to this request.
          *
          * <p><b>NOTE: </b>Do not modify this bundle.
diff --git a/core/java/android/view/textclassifier/SelectionEvent.java b/core/java/android/view/textclassifier/SelectionEvent.java
index 9ae0c65..5c81664 100644
--- a/core/java/android/view/textclassifier/SelectionEvent.java
+++ b/core/java/android/view/textclassifier/SelectionEvent.java
@@ -19,8 +19,10 @@
 import android.annotation.IntDef;
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.annotation.UserIdInt;
 import android.os.Parcel;
 import android.os.Parcelable;
+import android.os.UserHandle;
 import android.view.textclassifier.TextClassifier.EntityType;
 import android.view.textclassifier.TextClassifier.WidgetType;
 
@@ -127,6 +129,7 @@
     private String mWidgetType = TextClassifier.WIDGET_TYPE_UNKNOWN;
     private @InvocationMethod int mInvocationMethod;
     @Nullable private String mWidgetVersion;
+    private @UserIdInt int mUserId = UserHandle.USER_NULL;
     @Nullable private String mResultId;
     private long mEventTime;
     private long mDurationSinceSessionStart;
@@ -158,6 +161,7 @@
         mEntityType = in.readString();
         mWidgetVersion = in.readInt() > 0 ? in.readString() : null;
         mPackageName = in.readString();
+        mUserId = in.readInt();
         mWidgetType = in.readString();
         mInvocationMethod = in.readInt();
         mResultId = in.readString();
@@ -184,6 +188,7 @@
             dest.writeString(mWidgetVersion);
         }
         dest.writeString(mPackageName);
+        dest.writeInt(mUserId);
         dest.writeString(mWidgetType);
         dest.writeInt(mInvocationMethod);
         dest.writeString(mResultId);
@@ -401,6 +406,24 @@
     }
 
     /**
+     * Sets the id of this event's user.
+     * <p>
+     * Package-private for SystemTextClassifier's use.
+     */
+    void setUserId(@UserIdInt int userId) {
+        mUserId = userId;
+    }
+
+    /**
+     * Returns the id of this event's user.
+     * @hide
+     */
+    @UserIdInt
+    public int getUserId() {
+        return mUserId;
+    }
+
+    /**
      * Returns the type of widget that was involved in triggering this event.
      */
     @WidgetType
@@ -426,6 +449,7 @@
         mPackageName = context.getPackageName();
         mWidgetType = context.getWidgetType();
         mWidgetVersion = context.getWidgetVersion();
+        mUserId = context.getUserId();
     }
 
     /**
@@ -612,7 +636,7 @@
     @Override
     public int hashCode() {
         return Objects.hash(mAbsoluteStart, mAbsoluteEnd, mEventType, mEntityType,
-                mWidgetVersion, mPackageName, mWidgetType, mInvocationMethod, mResultId,
+                mWidgetVersion, mPackageName, mUserId, mWidgetType, mInvocationMethod, mResultId,
                 mEventTime, mDurationSinceSessionStart, mDurationSincePreviousEvent,
                 mEventIndex, mSessionId, mStart, mEnd, mSmartStart, mSmartEnd);
     }
@@ -633,6 +657,7 @@
                 && Objects.equals(mEntityType, other.mEntityType)
                 && Objects.equals(mWidgetVersion, other.mWidgetVersion)
                 && Objects.equals(mPackageName, other.mPackageName)
+                && mUserId == other.mUserId
                 && Objects.equals(mWidgetType, other.mWidgetType)
                 && mInvocationMethod == other.mInvocationMethod
                 && Objects.equals(mResultId, other.mResultId)
@@ -652,12 +677,12 @@
         return String.format(Locale.US,
                 "SelectionEvent {absoluteStart=%d, absoluteEnd=%d, eventType=%d, entityType=%s, "
                         + "widgetVersion=%s, packageName=%s, widgetType=%s, invocationMethod=%s, "
-                        + "resultId=%s, eventTime=%d, durationSinceSessionStart=%d, "
+                        + "userId=%d, resultId=%s, eventTime=%d, durationSinceSessionStart=%d, "
                         + "durationSincePreviousEvent=%d, eventIndex=%d,"
                         + "sessionId=%s, start=%d, end=%d, smartStart=%d, smartEnd=%d}",
                 mAbsoluteStart, mAbsoluteEnd, mEventType, mEntityType,
                 mWidgetVersion, mPackageName, mWidgetType, mInvocationMethod,
-                mResultId, mEventTime, mDurationSinceSessionStart,
+                mUserId, mResultId, mEventTime, mDurationSinceSessionStart,
                 mDurationSincePreviousEvent, mEventIndex,
                 mSessionId, mStart, mEnd, mSmartStart, mSmartEnd);
     }
diff --git a/core/java/android/view/textclassifier/SystemTextClassifier.java b/core/java/android/view/textclassifier/SystemTextClassifier.java
index 8f8766e..a97c330 100644
--- a/core/java/android/view/textclassifier/SystemTextClassifier.java
+++ b/core/java/android/view/textclassifier/SystemTextClassifier.java
@@ -18,6 +18,7 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.annotation.UserIdInt;
 import android.annotation.WorkerThread;
 import android.content.Context;
 import android.os.Bundle;
@@ -50,6 +51,10 @@
     private final TextClassificationConstants mSettings;
     private final TextClassifier mFallback;
     private final String mPackageName;
+    // NOTE: Always set this before sending a request to the manager service otherwise the manager
+    // service will throw a remote exception.
+    @UserIdInt
+    private final int mUserId;
     private TextClassificationSessionId mSessionId;
 
     public SystemTextClassifier(Context context, TextClassificationConstants settings)
@@ -60,6 +65,7 @@
         mFallback = context.getSystemService(TextClassificationManager.class)
                 .getTextClassifier(TextClassifier.LOCAL);
         mPackageName = Preconditions.checkNotNull(context.getOpPackageName());
+        mUserId = context.getUserId();
     }
 
     /**
@@ -72,6 +78,7 @@
         Utils.checkMainThread();
         try {
             request.setCallingPackageName(mPackageName);
+            request.setUserId(mUserId);
             final BlockingCallback<TextSelection> callback =
                     new BlockingCallback<>("textselection");
             mManagerService.onSuggestSelection(mSessionId, request, callback);
@@ -95,6 +102,7 @@
         Utils.checkMainThread();
         try {
             request.setCallingPackageName(mPackageName);
+            request.setUserId(mUserId);
             final BlockingCallback<TextClassification> callback =
                     new BlockingCallback<>("textclassification");
             mManagerService.onClassifyText(mSessionId, request, callback);
@@ -123,6 +131,7 @@
 
         try {
             request.setCallingPackageName(mPackageName);
+            request.setUserId(mUserId);
             final BlockingCallback<TextLinks> callback =
                     new BlockingCallback<>("textlinks");
             mManagerService.onGenerateLinks(mSessionId, request, callback);
@@ -142,6 +151,7 @@
         Utils.checkMainThread();
 
         try {
+            event.setUserId(mUserId);
             mManagerService.onSelectionEvent(mSessionId, event);
         } catch (RemoteException e) {
             Log.e(LOG_TAG, "Error reporting selection event.", e);
@@ -154,6 +164,12 @@
         Utils.checkMainThread();
 
         try {
+            final TextClassificationContext tcContext = event.getEventContext() == null
+                    ? new TextClassificationContext.Builder(mPackageName, WIDGET_TYPE_UNKNOWN)
+                            .build()
+                    : event.getEventContext();
+            tcContext.setUserId(mUserId);
+            event.setEventContext(tcContext);
             mManagerService.onTextClassifierEvent(mSessionId, event);
         } catch (RemoteException e) {
             Log.e(LOG_TAG, "Error reporting textclassifier event.", e);
@@ -167,6 +183,7 @@
 
         try {
             request.setCallingPackageName(mPackageName);
+            request.setUserId(mUserId);
             final BlockingCallback<TextLanguage> callback =
                     new BlockingCallback<>("textlanguage");
             mManagerService.onDetectLanguage(mSessionId, request, callback);
@@ -187,6 +204,7 @@
 
         try {
             request.setCallingPackageName(mPackageName);
+            request.setUserId(mUserId);
             final BlockingCallback<ConversationActions> callback =
                     new BlockingCallback<>("conversation-actions");
             mManagerService.onSuggestConversationActions(mSessionId, request, callback);
@@ -228,6 +246,7 @@
         printWriter.printPair("mFallback", mFallback);
         printWriter.printPair("mPackageName", mPackageName);
         printWriter.printPair("mSessionId", mSessionId);
+        printWriter.printPair("mUserId", mUserId);
         printWriter.decreaseIndent();
         printWriter.println();
     }
@@ -243,6 +262,7 @@
             @NonNull TextClassificationSessionId sessionId) {
         mSessionId = Preconditions.checkNotNull(sessionId);
         try {
+            classificationContext.setUserId(mUserId);
             mManagerService.onCreateTextClassificationSession(classificationContext, mSessionId);
         } catch (RemoteException e) {
             Log.e(LOG_TAG, "Error starting a new classification session.", e);
diff --git a/core/java/android/view/textclassifier/TextClassification.java b/core/java/android/view/textclassifier/TextClassification.java
index 6321051..975f3ba 100644
--- a/core/java/android/view/textclassifier/TextClassification.java
+++ b/core/java/android/view/textclassifier/TextClassification.java
@@ -21,6 +21,7 @@
 import android.annotation.IntRange;
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.annotation.UserIdInt;
 import android.app.PendingIntent;
 import android.app.RemoteAction;
 import android.content.Context;
@@ -35,6 +36,7 @@
 import android.os.LocaleList;
 import android.os.Parcel;
 import android.os.Parcelable;
+import android.os.UserHandle;
 import android.text.SpannedString;
 import android.util.ArrayMap;
 import android.view.View.OnClickListener;
@@ -551,6 +553,8 @@
         @Nullable private final ZonedDateTime mReferenceTime;
         @NonNull private final Bundle mExtras;
         @Nullable private String mCallingPackageName;
+        @UserIdInt
+        private int mUserId = UserHandle.USER_NULL;
 
         private Request(
                 CharSequence text,
@@ -631,6 +635,24 @@
         }
 
         /**
+         * Sets the id of the user that sent this request.
+         * <p>
+         * Package-private for SystemTextClassifier's use.
+         */
+        void setUserId(@UserIdInt int userId) {
+            mUserId = userId;
+        }
+
+        /**
+         * Returns the id of the user that sent this request.
+         * @hide
+         */
+        @UserIdInt
+        public int getUserId() {
+            return mUserId;
+        }
+
+        /**
          * Returns the extended data.
          *
          * <p><b>NOTE: </b>Do not modify this bundle.
@@ -730,6 +752,7 @@
             dest.writeParcelable(mDefaultLocales, flags);
             dest.writeString(mReferenceTime == null ? null : mReferenceTime.toString());
             dest.writeString(mCallingPackageName);
+            dest.writeInt(mUserId);
             dest.writeBundle(mExtras);
         }
 
@@ -742,11 +765,13 @@
             final ZonedDateTime referenceTime = referenceTimeString == null
                     ? null : ZonedDateTime.parse(referenceTimeString);
             final String callingPackageName = in.readString();
+            final int userId = in.readInt();
             final Bundle extras = in.readBundle();
 
             final Request request = new Request(text, startIndex, endIndex,
                     defaultLocales, referenceTime, extras);
             request.setCallingPackageName(callingPackageName);
+            request.setUserId(userId);
             return request;
         }
 
diff --git a/core/java/android/view/textclassifier/TextClassificationContext.java b/core/java/android/view/textclassifier/TextClassificationContext.java
index 3bf8e9b..db07685 100644
--- a/core/java/android/view/textclassifier/TextClassificationContext.java
+++ b/core/java/android/view/textclassifier/TextClassificationContext.java
@@ -18,8 +18,10 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.annotation.UserIdInt;
 import android.os.Parcel;
 import android.os.Parcelable;
+import android.os.UserHandle;
 import android.view.textclassifier.TextClassifier.WidgetType;
 
 import com.android.internal.util.Preconditions;
@@ -35,6 +37,8 @@
     private final String mPackageName;
     private final String mWidgetType;
     @Nullable private final String mWidgetVersion;
+    @UserIdInt
+    private int mUserId = UserHandle.USER_NULL;
 
     private TextClassificationContext(
             String packageName,
@@ -54,6 +58,24 @@
     }
 
     /**
+     * Sets the id of this context's user.
+     * <p>
+     * Package-private for SystemTextClassifier's use.
+     */
+    void setUserId(@UserIdInt int userId) {
+        mUserId = userId;
+    }
+
+    /**
+     * Returns the id of this context's user.
+     * @hide
+     */
+    @UserIdInt
+    public int getUserId() {
+        return mUserId;
+    }
+
+    /**
      * Returns the widget type for this classification context.
      */
     @NonNull
@@ -75,8 +97,8 @@
     @Override
     public String toString() {
         return String.format(Locale.US, "TextClassificationContext{"
-                + "packageName=%s, widgetType=%s, widgetVersion=%s}",
-                mPackageName, mWidgetType, mWidgetVersion);
+                + "packageName=%s, widgetType=%s, widgetVersion=%s, userId=%d}",
+                mPackageName, mWidgetType, mWidgetVersion, mUserId);
     }
 
     /**
@@ -133,12 +155,14 @@
         parcel.writeString(mPackageName);
         parcel.writeString(mWidgetType);
         parcel.writeString(mWidgetVersion);
+        parcel.writeInt(mUserId);
     }
 
     private TextClassificationContext(Parcel in) {
         mPackageName = in.readString();
         mWidgetType = in.readString();
         mWidgetVersion = in.readString();
+        mUserId = in.readInt();
     }
 
     public static final @android.annotation.NonNull Parcelable.Creator<TextClassificationContext> CREATOR =
diff --git a/core/java/android/view/textclassifier/TextClassifierEvent.java b/core/java/android/view/textclassifier/TextClassifierEvent.java
index 57da829..a041296 100644
--- a/core/java/android/view/textclassifier/TextClassifierEvent.java
+++ b/core/java/android/view/textclassifier/TextClassifierEvent.java
@@ -139,7 +139,7 @@
     @Nullable
     private final String[] mEntityTypes;
     @Nullable
-    private final TextClassificationContext mEventContext;
+    private TextClassificationContext mEventContext;
     @Nullable
     private final String mResultId;
     private final int mEventIndex;
@@ -289,6 +289,15 @@
     }
 
     /**
+     * Sets the event context.
+     * <p>
+     * Package-private for SystemTextClassifier's use.
+     */
+    void setEventContext(@Nullable TextClassificationContext eventContext) {
+        mEventContext = eventContext;
+    }
+
+    /**
      * Returns the id of the text classifier result related to this event.
      */
     @Nullable
diff --git a/core/java/android/view/textclassifier/TextLanguage.java b/core/java/android/view/textclassifier/TextLanguage.java
index 6c75ffb..3e3dc72 100644
--- a/core/java/android/view/textclassifier/TextLanguage.java
+++ b/core/java/android/view/textclassifier/TextLanguage.java
@@ -20,10 +20,12 @@
 import android.annotation.IntRange;
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.annotation.UserIdInt;
 import android.icu.util.ULocale;
 import android.os.Bundle;
 import android.os.Parcel;
 import android.os.Parcelable;
+import android.os.UserHandle;
 import android.util.ArrayMap;
 
 import com.android.internal.annotations.VisibleForTesting;
@@ -226,6 +228,8 @@
         private final CharSequence mText;
         private final Bundle mExtra;
         @Nullable private String mCallingPackageName;
+        @UserIdInt
+        private int mUserId = UserHandle.USER_NULL;
 
         private Request(CharSequence text, Bundle bundle) {
             mText = text;
@@ -260,6 +264,24 @@
         }
 
         /**
+         * Sets the id of the user that sent this request.
+         * <p>
+         * Package-private for SystemTextClassifier's use.
+         */
+        void setUserId(@UserIdInt int userId) {
+            mUserId = userId;
+        }
+
+        /**
+         * Returns the id of the user that sent this request.
+         * @hide
+         */
+        @UserIdInt
+        public int getUserId() {
+            return mUserId;
+        }
+
+        /**
          * Returns a bundle containing non-structured extra information about this request.
          *
          * <p><b>NOTE: </b>Do not modify this bundle.
@@ -278,16 +300,19 @@
         public void writeToParcel(Parcel dest, int flags) {
             dest.writeCharSequence(mText);
             dest.writeString(mCallingPackageName);
+            dest.writeInt(mUserId);
             dest.writeBundle(mExtra);
         }
 
         private static Request readFromParcel(Parcel in) {
             final CharSequence text = in.readCharSequence();
             final String callingPackageName = in.readString();
+            final int userId = in.readInt();
             final Bundle extra = in.readBundle();
 
             final Request request = new Request(text, extra);
             request.setCallingPackageName(callingPackageName);
+            request.setUserId(userId);
             return request;
         }
 
diff --git a/core/java/android/view/textclassifier/TextLinks.java b/core/java/android/view/textclassifier/TextLinks.java
index f3e0dc1..d7ac524 100644
--- a/core/java/android/view/textclassifier/TextLinks.java
+++ b/core/java/android/view/textclassifier/TextLinks.java
@@ -20,11 +20,13 @@
 import android.annotation.IntDef;
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.annotation.UserIdInt;
 import android.content.Context;
 import android.os.Bundle;
 import android.os.LocaleList;
 import android.os.Parcel;
 import android.os.Parcelable;
+import android.os.UserHandle;
 import android.text.Spannable;
 import android.text.method.MovementMethod;
 import android.text.style.ClickableSpan;
@@ -339,6 +341,8 @@
         private final boolean mLegacyFallback;
         @Nullable private String mCallingPackageName;
         private final Bundle mExtras;
+        @UserIdInt
+        private int mUserId = UserHandle.USER_NULL;
 
         private Request(
                 CharSequence text,
@@ -410,6 +414,24 @@
         }
 
         /**
+         * Sets the id of the user that sent this request.
+         * <p>
+         * Package-private for SystemTextClassifier's use.
+         */
+        void setUserId(@UserIdInt int userId) {
+            mUserId = userId;
+        }
+
+        /**
+         * Returns the id of the user that sent this request.
+         * @hide
+         */
+        @UserIdInt
+        public int getUserId() {
+            return mUserId;
+        }
+
+        /**
          * Returns the extended data.
          *
          * <p><b>NOTE: </b>Do not modify this bundle.
@@ -509,6 +531,7 @@
             dest.writeParcelable(mDefaultLocales, flags);
             dest.writeParcelable(mEntityConfig, flags);
             dest.writeString(mCallingPackageName);
+            dest.writeInt(mUserId);
             dest.writeBundle(mExtras);
         }
 
@@ -517,11 +540,13 @@
             final LocaleList defaultLocales = in.readParcelable(null);
             final EntityConfig entityConfig = in.readParcelable(null);
             final String callingPackageName = in.readString();
+            final int userId = in.readInt();
             final Bundle extras = in.readBundle();
 
             final Request request = new Request(text, defaultLocales, entityConfig,
                     /* legacyFallback= */ true, extras);
             request.setCallingPackageName(callingPackageName);
+            request.setUserId(userId);
             return request;
         }
 
diff --git a/core/java/android/view/textclassifier/TextSelection.java b/core/java/android/view/textclassifier/TextSelection.java
index 75c27bd..94e0bc3 100644
--- a/core/java/android/view/textclassifier/TextSelection.java
+++ b/core/java/android/view/textclassifier/TextSelection.java
@@ -20,10 +20,12 @@
 import android.annotation.IntRange;
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.annotation.UserIdInt;
 import android.os.Bundle;
 import android.os.LocaleList;
 import android.os.Parcel;
 import android.os.Parcelable;
+import android.os.UserHandle;
 import android.text.SpannedString;
 import android.util.ArrayMap;
 import android.view.textclassifier.TextClassifier.EntityType;
@@ -211,6 +213,8 @@
         private final boolean mDarkLaunchAllowed;
         private final Bundle mExtras;
         @Nullable private String mCallingPackageName;
+        @UserIdInt
+        private int mUserId = UserHandle.USER_NULL;
 
         private Request(
                 CharSequence text,
@@ -292,6 +296,24 @@
         }
 
         /**
+         * Sets the id of the user that sent this request.
+         * <p>
+         * Package-private for SystemTextClassifier's use.
+         */
+        void setUserId(@UserIdInt int userId) {
+            mUserId = userId;
+        }
+
+        /**
+         * Returns the id of the user that sent this request.
+         * @hide
+         */
+        @UserIdInt
+        public int getUserId() {
+            return mUserId;
+        }
+
+        /**
          * Returns the extended data.
          *
          * <p><b>NOTE: </b>Do not modify this bundle.
@@ -394,6 +416,7 @@
             dest.writeInt(mEndIndex);
             dest.writeParcelable(mDefaultLocales, flags);
             dest.writeString(mCallingPackageName);
+            dest.writeInt(mUserId);
             dest.writeBundle(mExtras);
         }
 
@@ -403,11 +426,13 @@
             final int endIndex = in.readInt();
             final LocaleList defaultLocales = in.readParcelable(null);
             final String callingPackageName = in.readString();
+            final int userId = in.readInt();
             final Bundle extras = in.readBundle();
 
             final Request request = new Request(text, startIndex, endIndex, defaultLocales,
                     /* darkLaunchAllowed= */ false, extras);
             request.setCallingPackageName(callingPackageName);
+            request.setUserId(userId);
             return request;
         }
 
diff --git a/core/java/android/widget/TextView.java b/core/java/android/widget/TextView.java
index 0918c5f..944bdcd 100644
--- a/core/java/android/widget/TextView.java
+++ b/core/java/android/widget/TextView.java
@@ -11283,6 +11283,12 @@
     }
 
     @Nullable
+    final TextClassificationManager getTextClassificationManagerForUser() {
+        return getServiceManagerForUser(
+                getContext().getPackageName(), TextClassificationManager.class);
+    }
+
+    @Nullable
     final <T> T getServiceManagerForUser(String packageName, Class<T> managerClazz) {
         if (mTextOperationUser == null) {
             return getContext().getSystemService(managerClazz);
@@ -12383,8 +12389,7 @@
     @NonNull
     public TextClassifier getTextClassifier() {
         if (mTextClassifier == null) {
-            final TextClassificationManager tcm =
-                    mContext.getSystemService(TextClassificationManager.class);
+            final TextClassificationManager tcm = getTextClassificationManagerForUser();
             if (tcm != null) {
                 return tcm.getTextClassifier();
             }
@@ -12400,8 +12405,7 @@
     @NonNull
     TextClassifier getTextClassificationSession() {
         if (mTextClassificationSession == null || mTextClassificationSession.isDestroyed()) {
-            final TextClassificationManager tcm =
-                    mContext.getSystemService(TextClassificationManager.class);
+            final TextClassificationManager tcm = getTextClassificationManagerForUser();
             if (tcm != null) {
                 final String widgetType;
                 if (isTextEditable()) {
diff --git a/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java b/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java
index 5f00148..89a5305 100644
--- a/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java
+++ b/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java
@@ -30,6 +30,7 @@
 import android.service.textclassifier.ITextClassifierCallback;
 import android.service.textclassifier.ITextClassifierService;
 import android.service.textclassifier.TextClassifierService;
+import android.util.ArrayMap;
 import android.util.Slog;
 import android.util.SparseArray;
 import android.view.textclassifier.ConversationActions;
@@ -54,6 +55,7 @@
 import java.io.FileDescriptor;
 import java.io.PrintWriter;
 import java.util.ArrayDeque;
+import java.util.Map;
 import java.util.Queue;
 
 /**
@@ -119,6 +121,8 @@
     private final Object mLock;
     @GuardedBy("mLock")
     final SparseArray<UserState> mUserStates = new SparseArray<>();
+    @GuardedBy("mLock")
+    private final Map<TextClassificationSessionId, Integer> mSessionUserIds = new ArrayMap<>();
 
     private TextClassificationManagerService(Context context) {
         mContext = Preconditions.checkNotNull(context);
@@ -127,15 +131,16 @@
 
     @Override
     public void onSuggestSelection(
-            TextClassificationSessionId sessionId,
+            @Nullable TextClassificationSessionId sessionId,
             TextSelection.Request request, ITextClassifierCallback callback)
             throws RemoteException {
         Preconditions.checkNotNull(request);
         Preconditions.checkNotNull(callback);
-        validateInput(mContext, request.getCallingPackageName());
+        final int userId = request.getUserId();
+        validateInput(mContext, request.getCallingPackageName(), userId);
 
         synchronized (mLock) {
-            UserState userState = getCallingUserStateLocked();
+            UserState userState = getUserStateLocked(userId);
             if (!userState.bindLocked()) {
                 callback.onFailure();
             } else if (userState.isBoundLocked()) {
@@ -150,15 +155,16 @@
 
     @Override
     public void onClassifyText(
-            TextClassificationSessionId sessionId,
+            @Nullable TextClassificationSessionId sessionId,
             TextClassification.Request request, ITextClassifierCallback callback)
             throws RemoteException {
         Preconditions.checkNotNull(request);
         Preconditions.checkNotNull(callback);
-        validateInput(mContext, request.getCallingPackageName());
+        final int userId = request.getUserId();
+        validateInput(mContext, request.getCallingPackageName(), userId);
 
         synchronized (mLock) {
-            UserState userState = getCallingUserStateLocked();
+            UserState userState = getUserStateLocked(userId);
             if (!userState.bindLocked()) {
                 callback.onFailure();
             } else if (userState.isBoundLocked()) {
@@ -173,15 +179,16 @@
 
     @Override
     public void onGenerateLinks(
-            TextClassificationSessionId sessionId,
+            @Nullable TextClassificationSessionId sessionId,
             TextLinks.Request request, ITextClassifierCallback callback)
             throws RemoteException {
         Preconditions.checkNotNull(request);
         Preconditions.checkNotNull(callback);
-        validateInput(mContext, request.getCallingPackageName());
+        final int userId = request.getUserId();
+        validateInput(mContext, request.getCallingPackageName(), userId);
 
         synchronized (mLock) {
-            UserState userState = getCallingUserStateLocked();
+            UserState userState = getUserStateLocked(userId);
             if (!userState.bindLocked()) {
                 callback.onFailure();
             } else if (userState.isBoundLocked()) {
@@ -196,12 +203,14 @@
 
     @Override
     public void onSelectionEvent(
-            TextClassificationSessionId sessionId, SelectionEvent event) throws RemoteException {
+            @Nullable TextClassificationSessionId sessionId, SelectionEvent event)
+            throws RemoteException {
         Preconditions.checkNotNull(event);
-        validateInput(mContext, event.getPackageName());
+        final int userId = event.getUserId();
+        validateInput(mContext, event.getPackageName(), userId);
 
         synchronized (mLock) {
-            UserState userState = getCallingUserStateLocked();
+            UserState userState = getUserStateLocked(userId);
             if (userState.isBoundLocked()) {
                 userState.mService.onSelectionEvent(sessionId, event);
             } else {
@@ -213,16 +222,19 @@
     }
     @Override
     public void onTextClassifierEvent(
-            TextClassificationSessionId sessionId,
+            @Nullable TextClassificationSessionId sessionId,
             TextClassifierEvent event) throws RemoteException {
         Preconditions.checkNotNull(event);
         final String packageName = event.getEventContext() == null
                 ? null
                 : event.getEventContext().getPackageName();
-        validateInput(mContext, packageName);
+        final int userId = event.getEventContext() == null
+                ? UserHandle.getCallingUserId()
+                : event.getEventContext().getUserId();
+        validateInput(mContext, packageName, userId);
 
         synchronized (mLock) {
-            UserState userState = getCallingUserStateLocked();
+            UserState userState = getUserStateLocked(userId);
             if (userState.isBoundLocked()) {
                 userState.mService.onTextClassifierEvent(sessionId, event);
             } else {
@@ -235,15 +247,16 @@
 
     @Override
     public void onDetectLanguage(
-            TextClassificationSessionId sessionId,
+            @Nullable TextClassificationSessionId sessionId,
             TextLanguage.Request request,
             ITextClassifierCallback callback) throws RemoteException {
         Preconditions.checkNotNull(request);
         Preconditions.checkNotNull(callback);
-        validateInput(mContext, request.getCallingPackageName());
+        final int userId = request.getUserId();
+        validateInput(mContext, request.getCallingPackageName(), userId);
 
         synchronized (mLock) {
-            UserState userState = getCallingUserStateLocked();
+            UserState userState = getUserStateLocked(userId);
             if (!userState.bindLocked()) {
                 callback.onFailure();
             } else if (userState.isBoundLocked()) {
@@ -258,15 +271,16 @@
 
     @Override
     public void onSuggestConversationActions(
-            TextClassificationSessionId sessionId,
+            @Nullable TextClassificationSessionId sessionId,
             ConversationActions.Request request,
             ITextClassifierCallback callback) throws RemoteException {
         Preconditions.checkNotNull(request);
         Preconditions.checkNotNull(callback);
-        validateInput(mContext, request.getCallingPackageName());
+        final int userId = request.getUserId();
+        validateInput(mContext, request.getCallingPackageName(), userId);
 
         synchronized (mLock) {
-            UserState userState = getCallingUserStateLocked();
+            UserState userState = getUserStateLocked(userId);
             if (!userState.bindLocked()) {
                 callback.onFailure();
             } else if (userState.isBoundLocked()) {
@@ -285,13 +299,15 @@
             throws RemoteException {
         Preconditions.checkNotNull(sessionId);
         Preconditions.checkNotNull(classificationContext);
-        validateInput(mContext, classificationContext.getPackageName());
+        final int userId = classificationContext.getUserId();
+        validateInput(mContext, classificationContext.getPackageName(), userId);
 
         synchronized (mLock) {
-            UserState userState = getCallingUserStateLocked();
+            UserState userState = getUserStateLocked(userId);
             if (userState.isBoundLocked()) {
                 userState.mService.onCreateTextClassificationSession(
                         classificationContext, sessionId);
+                mSessionUserIds.put(sessionId, userId);
             } else {
                 userState.mPendingRequests.add(new PendingRequest(
                         () -> onCreateTextClassificationSession(classificationContext, sessionId),
@@ -306,9 +322,15 @@
         Preconditions.checkNotNull(sessionId);
 
         synchronized (mLock) {
-            UserState userState = getCallingUserStateLocked();
+            final int userId = mSessionUserIds.containsKey(sessionId)
+                    ? mSessionUserIds.get(sessionId)
+                    : UserHandle.getCallingUserId();
+            validateInput(mContext, null /* packageName */, userId);
+
+            UserState userState = getUserStateLocked(userId);
             if (userState.isBoundLocked()) {
                 userState.mService.onDestroyTextClassificationSession(sessionId);
+                mSessionUserIds.remove(sessionId);
             } else {
                 userState.mPendingRequests.add(new PendingRequest(
                         () -> onDestroyTextClassificationSession(sessionId),
@@ -318,11 +340,6 @@
     }
 
     @GuardedBy("mLock")
-    private UserState getCallingUserStateLocked() {
-        return getUserStateLocked(UserHandle.getCallingUserId());
-    }
-
-    @GuardedBy("mLock")
     private UserState getUserStateLocked(int userId) {
         UserState result = mUserStates.get(userId);
         if (result == null) {
@@ -356,6 +373,7 @@
                     pw.decreaseIndent();
                 }
             }
+            pw.println("Number of active sessions: " + mSessionUserIds.size());
         }
     }
 
@@ -420,20 +438,32 @@
                 e -> Slog.d(LOG_TAG, "Error " + opDesc + ": " + e.getMessage()));
     }
 
-    private static void validateInput(Context context, @Nullable String packageName)
+    private static void validateInput(
+            Context context, @Nullable String packageName, @UserIdInt int userId)
             throws RemoteException {
-        if (packageName == null) return;
 
         try {
-            final int packageUid = context.getPackageManager()
-                    .getPackageUidAsUser(packageName, UserHandle.getCallingUserId());
-            final int callingUid = Binder.getCallingUid();
-            Preconditions.checkArgument(callingUid == packageUid
-                    // Trust the system process:
-                    || callingUid == android.os.Process.SYSTEM_UID);
+            if (packageName != null) {
+                final int packageUid = context.getPackageManager()
+                        .getPackageUidAsUser(packageName, UserHandle.getCallingUserId());
+                final int callingUid = Binder.getCallingUid();
+                Preconditions.checkArgument(callingUid == packageUid
+                        // Trust the system process:
+                        || callingUid == android.os.Process.SYSTEM_UID,
+                        "Invalid package name. Package=" + packageName
+                                + ", CallingUid=" + callingUid);
+            }
+
+            Preconditions.checkArgument(userId != UserHandle.USER_NULL, "Null userId");
+            final int callingUserId = UserHandle.getCallingUserId();
+            if (callingUserId != userId) {
+                context.enforceCallingOrSelfPermission(
+                        android.Manifest.permission.INTERACT_ACROSS_USERS_FULL,
+                        "Invalid userId. UserId=" + userId + ", CallingUserId=" + callingUserId);
+            }
         } catch (Exception e) {
-            throw new RemoteException(
-                    String.format("Invalid package: name=%s, error=%s", packageName, e));
+            throw new RemoteException("Invalid request: " + e.getMessage(), e,
+                    /* enableSuppression */ true, /* writableStackTrace */ true);
         }
     }