Correctly count nri uid request counts

Correctly count nri uid request counts in the per-app functionality in
connectivity currently used by set profile and set oem network
preference APIs. Previously, upon creation, nris would be created prior
to removing them. This would cause the uid request counts to
artificially increase and incorrectly throw an error if the request
count limit was hit even though in actuality an apps request count was
valid.

E.g., if there was an existing request for per-app functionality and
its owning app made a change to the per-app requests, it would double
count the existing requests. If the current count was say, one under the
limit, an error would be thrown even though it was being replaced which
should have resulted in no net change to the request count limit if
working correctly.

This patch will allow for the requests to be removed prior to creation
so that request counts are tabulated correctly.

Bug: 185849563
Bug: 183785319
Test: atest FrameworksNetTests
Change-Id: I13da0c81256cc02bea6aff2fe1ef99d6f6b0e764
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index 2216e89..6460053 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -319,7 +319,8 @@
     private static final int MAX_NETWORK_REQUESTS_PER_UID = 100;
 
     // The maximum number of network request allowed for system UIDs before an exception is thrown.
-    private static final int MAX_NETWORK_REQUESTS_PER_SYSTEM_UID = 250;
+    @VisibleForTesting
+    static final int MAX_NETWORK_REQUESTS_PER_SYSTEM_UID = 250;
 
     @VisibleForTesting
     protected int mLingerDelayMs;  // Can't be final, or test subclass constructors can't change it.
@@ -336,7 +337,8 @@
     protected final PermissionMonitor mPermissionMonitor;
 
     private final PerUidCounter mNetworkRequestCounter;
-    private final PerUidCounter mSystemNetworkRequestCounter;
+    @VisibleForTesting
+    final PerUidCounter mSystemNetworkRequestCounter;
 
     private volatile boolean mLockdownEnabled;
 
@@ -1047,8 +1049,9 @@
         private final int mMaxCountPerUid;
 
         // Map from UID to number of NetworkRequests that UID has filed.
+        @VisibleForTesting
         @GuardedBy("mUidToNetworkRequestCount")
-        private final SparseIntArray mUidToNetworkRequestCount = new SparseIntArray();
+        final SparseIntArray mUidToNetworkRequestCount = new SparseIntArray();
 
         /**
          * Constructor
@@ -1072,15 +1075,20 @@
          */
         public void incrementCountOrThrow(final int uid) {
             synchronized (mUidToNetworkRequestCount) {
-                final int networkRequests = mUidToNetworkRequestCount.get(uid, 0) + 1;
-                if (networkRequests >= mMaxCountPerUid) {
-                    throw new ServiceSpecificException(
-                            ConnectivityManager.Errors.TOO_MANY_REQUESTS);
-                }
-                mUidToNetworkRequestCount.put(uid, networkRequests);
+                incrementCountOrThrow(uid, 1 /* numToIncrement */);
             }
         }
 
+        private void incrementCountOrThrow(final int uid, final int numToIncrement) {
+            final int newRequestCount =
+                    mUidToNetworkRequestCount.get(uid, 0) + numToIncrement;
+            if (newRequestCount >= mMaxCountPerUid) {
+                throw new ServiceSpecificException(
+                        ConnectivityManager.Errors.TOO_MANY_REQUESTS);
+            }
+            mUidToNetworkRequestCount.put(uid, newRequestCount);
+        }
+
         /**
          * Decrements the request count of the given uid.
          *
@@ -1088,16 +1096,50 @@
          */
         public void decrementCount(final int uid) {
             synchronized (mUidToNetworkRequestCount) {
-                final int requests = mUidToNetworkRequestCount.get(uid, 0);
-                if (requests < 1) {
-                    logwtf("BUG: too small request count " + requests + " for UID " + uid);
-                } else if (requests == 1) {
-                    mUidToNetworkRequestCount.delete(uid);
-                } else {
-                    mUidToNetworkRequestCount.put(uid, requests - 1);
-                }
+                decrementCount(uid, 1 /* numToDecrement */);
             }
         }
+
+        private void decrementCount(final int uid, final int numToDecrement) {
+            final int newRequestCount =
+                    mUidToNetworkRequestCount.get(uid, 0) - numToDecrement;
+            if (newRequestCount < 0) {
+                logwtf("BUG: too small request count " + newRequestCount + " for UID " + uid);
+            } else if (newRequestCount == 0) {
+                mUidToNetworkRequestCount.delete(uid);
+            } else {
+                mUidToNetworkRequestCount.put(uid, newRequestCount);
+            }
+        }
+
+        /**
+         * Used to adjust the request counter for the per-app API flows. Directly adjusting the
+         * counter is not ideal however in the per-app flows, the nris can't be removed until they
+         * are used to create the new nris upon set. Therefore the request count limit can be
+         * artificially hit. This method is used as a workaround for this particular case so that
+         * the request counts are accounted for correctly.
+         * @param uid the uid to adjust counts for
+         * @param numOfNewRequests the new request count to account for
+         * @param r the runnable to execute
+         */
+        public void transact(final int uid, final int numOfNewRequests, @NonNull final Runnable r) {
+            // This should only be used on the handler thread as per all current and foreseen
+            // use-cases. ensureRunningOnConnectivityServiceThread() can't be used because there is
+            // no ref to the outer ConnectivityService.
+            synchronized (mUidToNetworkRequestCount) {
+                final int reqCountOverage = getCallingUidRequestCountOverage(uid, numOfNewRequests);
+                decrementCount(uid, reqCountOverage);
+                r.run();
+                incrementCountOrThrow(uid, reqCountOverage);
+            }
+        }
+
+        private int getCallingUidRequestCountOverage(final int uid, final int numOfNewRequests) {
+            final int newUidRequestCount = mUidToNetworkRequestCount.get(uid, 0)
+                    + numOfNewRequests;
+            return newUidRequestCount >= MAX_NETWORK_REQUESTS_PER_SYSTEM_UID
+                    ? newUidRequestCount - (MAX_NETWORK_REQUESTS_PER_SYSTEM_UID - 1) : 0;
+        }
     }
 
     /**
@@ -9569,9 +9611,13 @@
         validateNetworkCapabilitiesOfProfileNetworkPreference(preference.capabilities);
 
         mProfileNetworkPreferences = mProfileNetworkPreferences.plus(preference);
-        final ArraySet<NetworkRequestInfo> nris =
-                createNrisFromProfileNetworkPreferences(mProfileNetworkPreferences);
-        replaceDefaultNetworkRequestsForPreference(nris);
+        mSystemNetworkRequestCounter.transact(
+                mDeps.getCallingUid(), mProfileNetworkPreferences.preferences.size(),
+                () -> {
+                    final ArraySet<NetworkRequestInfo> nris =
+                            createNrisFromProfileNetworkPreferences(mProfileNetworkPreferences);
+                    replaceDefaultNetworkRequestsForPreference(nris);
+                });
         // Finally, rematch.
         rematchAllNetworksAndRequests();
 
@@ -9657,9 +9703,16 @@
         }
 
         mOemNetworkPreferencesLogs.log("UPDATE INITIATED: " + preference);
-        final ArraySet<NetworkRequestInfo> nris =
-                new OemNetworkRequestFactory().createNrisFromOemNetworkPreferences(preference);
-        replaceDefaultNetworkRequestsForPreference(nris);
+        final int uniquePreferenceCount = new ArraySet<>(
+                preference.getNetworkPreferences().values()).size();
+        mSystemNetworkRequestCounter.transact(
+                mDeps.getCallingUid(), uniquePreferenceCount,
+                () -> {
+                    final ArraySet<NetworkRequestInfo> nris =
+                            new OemNetworkRequestFactory()
+                                    .createNrisFromOemNetworkPreferences(preference);
+                    replaceDefaultNetworkRequestsForPreference(nris);
+                });
         mOemNetworkPreferences = preference;
 
         if (null != listener) {
@@ -9684,10 +9737,14 @@
         final ArraySet<NetworkRequestInfo> perAppCallbackRequestsToUpdate =
                 getPerAppCallbackRequestsToUpdate();
         final ArraySet<NetworkRequestInfo> nrisToRegister = new ArraySet<>(nris);
-        nrisToRegister.addAll(
-                createPerAppCallbackRequestsToRegister(perAppCallbackRequestsToUpdate));
-        handleRemoveNetworkRequests(perAppCallbackRequestsToUpdate);
-        handleRegisterNetworkRequests(nrisToRegister);
+        mSystemNetworkRequestCounter.transact(
+                mDeps.getCallingUid(), perAppCallbackRequestsToUpdate.size(),
+                () -> {
+                    nrisToRegister.addAll(
+                            createPerAppCallbackRequestsToRegister(perAppCallbackRequestsToUpdate));
+                    handleRemoveNetworkRequests(perAppCallbackRequestsToUpdate);
+                    handleRegisterNetworkRequests(nrisToRegister);
+                });
     }
 
     /**
diff --git a/tests/net/java/com/android/server/ConnectivityServiceTest.java b/tests/net/java/com/android/server/ConnectivityServiceTest.java
index 91811f6..c58e937 100644
--- a/tests/net/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/net/java/com/android/server/ConnectivityServiceTest.java
@@ -10246,12 +10246,15 @@
         return UidRange.createForUser(UserHandle.of(userId));
     }
 
-    private void mockGetApplicationInfo(@NonNull final String packageName, @NonNull final int uid)
-            throws Exception {
+    private void mockGetApplicationInfo(@NonNull final String packageName, @NonNull final int uid) {
         final ApplicationInfo applicationInfo = new ApplicationInfo();
         applicationInfo.uid = uid;
-        when(mPackageManager.getApplicationInfo(eq(packageName), anyInt()))
-                .thenReturn(applicationInfo);
+        try {
+            when(mPackageManager.getApplicationInfo(eq(packageName), anyInt()))
+                    .thenReturn(applicationInfo);
+        } catch (Exception e) {
+            fail(e.getMessage());
+        }
     }
 
     private void mockGetApplicationInfoThrowsNameNotFound(@NonNull final String packageName)
@@ -10272,8 +10275,7 @@
     }
 
     private OemNetworkPreferences createDefaultOemNetworkPreferences(
-            @OemNetworkPreferences.OemNetworkPreference final int preference)
-            throws Exception {
+            @OemNetworkPreferences.OemNetworkPreference final int preference) {
         // Arrange PackageManager mocks
         mockGetApplicationInfo(TEST_PACKAGE_NAME, TEST_PACKAGE_UID);
 
@@ -10750,11 +10752,13 @@
             mDone.complete(new Object());
         }
 
-        void expectOnComplete() throws Exception {
+        void expectOnComplete() {
             try {
                 mDone.get(TIMEOUT_MS, TimeUnit.MILLISECONDS);
             } catch (TimeoutException e) {
                 fail("Expected onComplete() not received after " + TIMEOUT_MS + " ms");
+            } catch (Exception e) {
+                fail(e.getMessage());
             }
         }
 
@@ -12369,4 +12373,72 @@
                 expected,
                 () -> mCm.registerNetworkCallback(getRequestWithSubIds(), new NetworkCallback()));
     }
+
+    /**
+     * Validate request counts are counted accurately on setProfileNetworkPreference on set/replace.
+     */
+    @Test
+    public void testProfileNetworkPrefCountsRequestsCorrectlyOnSet() throws Exception {
+        final UserHandle testHandle = setupEnterpriseNetwork();
+        testRequestCountLimits(() -> {
+            // Set initially to test the limit prior to having existing requests.
+            final TestOnCompleteListener listener = new TestOnCompleteListener();
+            mCm.setProfileNetworkPreference(testHandle, PROFILE_NETWORK_PREFERENCE_ENTERPRISE,
+                    Runnable::run, listener);
+            listener.expectOnComplete();
+
+            // re-set so as to test the limit as part of replacing existing requests.
+            mCm.setProfileNetworkPreference(testHandle, PROFILE_NETWORK_PREFERENCE_ENTERPRISE,
+                    Runnable::run, listener);
+            listener.expectOnComplete();
+        });
+    }
+
+    /**
+     * Validate request counts are counted accurately on setOemNetworkPreference on set/replace.
+     */
+    @Test
+    public void testSetOemNetworkPreferenceCountsRequestsCorrectlyOnSet() throws Exception {
+        mockHasSystemFeature(PackageManager.FEATURE_AUTOMOTIVE, true);
+        @OemNetworkPreferences.OemNetworkPreference final int networkPref =
+                OEM_NETWORK_PREFERENCE_OEM_PRIVATE_ONLY;
+        testRequestCountLimits(() -> {
+            // Set initially to test the limit prior to having existing requests.
+            final TestOemListenerCallback listener = new TestOemListenerCallback();
+            mService.setOemNetworkPreference(
+                    createDefaultOemNetworkPreferences(networkPref), listener);
+            listener.expectOnComplete();
+
+            // re-set so as to test the limit as part of replacing existing requests.
+            mService.setOemNetworkPreference(
+                    createDefaultOemNetworkPreferences(networkPref), listener);
+            listener.expectOnComplete();
+        });
+    }
+
+    private void testRequestCountLimits(@NonNull final Runnable r) throws Exception {
+        final ArraySet<TestNetworkCallback> callbacks = new ArraySet<>();
+        try {
+            final int requestCount = mService.mSystemNetworkRequestCounter
+                    .mUidToNetworkRequestCount.get(Process.myUid());
+            // The limit is hit when total requests <= limit.
+            final int maxCount =
+                    ConnectivityService.MAX_NETWORK_REQUESTS_PER_SYSTEM_UID - requestCount;
+            // Need permission so registerDefaultNetworkCallback uses mSystemNetworkRequestCounter
+            withPermission(NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK, () -> {
+                for (int i = 1; i < maxCount - 1; i++) {
+                    final TestNetworkCallback cb = new TestNetworkCallback();
+                    mCm.registerDefaultNetworkCallback(cb);
+                    callbacks.add(cb);
+                }
+
+                // Code to run to check if it triggers a max request count limit error.
+                r.run();
+            });
+        } finally {
+            for (final TestNetworkCallback cb : callbacks) {
+                mCm.unregisterNetworkCallback(cb);
+            }
+        }
+    }
 }