Revert "Revert "Update VPN capabilities when its underlying network set is null.""

This reverts commit 6a050c7c50fa0838d7e29b8c6e244018044246db.

Reason for revert: Retargeted for June monthly release

Bug: 119129310

Change-Id: I9d543415c5707859cfa2a14a1a8ce5909aae7d11
Merged-In: Id0abc4d304bb096e92479a118168690ccce634ed
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index c9f9ab6..1c66c5a 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -867,7 +867,8 @@
 
         mPermissionMonitor = new PermissionMonitor(mContext, mNetd);
 
-        //set up the listener for user state for creating user VPNs
+        // Set up the listener for user state for creating user VPNs.
+        // Should run on mHandler to avoid any races.
         IntentFilter intentFilter = new IntentFilter();
         intentFilter.addAction(Intent.ACTION_USER_STARTED);
         intentFilter.addAction(Intent.ACTION_USER_STOPPED);
@@ -875,7 +876,11 @@
         intentFilter.addAction(Intent.ACTION_USER_REMOVED);
         intentFilter.addAction(Intent.ACTION_USER_UNLOCKED);
         mContext.registerReceiverAsUser(
-                mUserIntentReceiver, UserHandle.ALL, intentFilter, null, null);
+                mUserIntentReceiver,
+                UserHandle.ALL,
+                intentFilter,
+                null /* broadcastPermission */,
+                mHandler);
         mContext.registerReceiverAsUser(mUserPresentReceiver, UserHandle.SYSTEM,
                 new IntentFilter(Intent.ACTION_USER_PRESENT), null, null);
 
@@ -3815,17 +3820,27 @@
      * handler thread through their agent, this is asynchronous. When the capabilities objects
      * are computed they will be up-to-date as they are computed synchronously from here and
      * this is running on the ConnectivityService thread.
-     * TODO : Fix this and call updateCapabilities inline to remove out-of-order events.
      */
     private void updateAllVpnsCapabilities() {
+        Network defaultNetwork = getNetwork(getDefaultNetwork());
         synchronized (mVpns) {
             for (int i = 0; i < mVpns.size(); i++) {
                 final Vpn vpn = mVpns.valueAt(i);
-                vpn.updateCapabilities();
+                NetworkCapabilities nc = vpn.updateCapabilities(defaultNetwork);
+                updateVpnCapabilities(vpn, nc);
             }
         }
     }
 
+    private void updateVpnCapabilities(Vpn vpn, @Nullable NetworkCapabilities nc) {
+        ensureRunningOnConnectivityServiceThread();
+        NetworkAgentInfo vpnNai = getNetworkAgentInfoForNetId(vpn.getNetId());
+        if (vpnNai == null || nc == null) {
+            return;
+        }
+        updateCapabilities(vpnNai.getCurrentScore(), vpnNai, nc);
+    }
+
     @Override
     public boolean updateLockdownVpn() {
         if (Binder.getCallingUid() != Process.SYSTEM_UID) {
@@ -4132,21 +4147,27 @@
     }
 
     private void onUserAdded(int userId) {
+        Network defaultNetwork = getNetwork(getDefaultNetwork());
         synchronized (mVpns) {
             final int vpnsSize = mVpns.size();
             for (int i = 0; i < vpnsSize; i++) {
                 Vpn vpn = mVpns.valueAt(i);
                 vpn.onUserAdded(userId);
+                NetworkCapabilities nc = vpn.updateCapabilities(defaultNetwork);
+                updateVpnCapabilities(vpn, nc);
             }
         }
     }
 
     private void onUserRemoved(int userId) {
+        Network defaultNetwork = getNetwork(getDefaultNetwork());
         synchronized (mVpns) {
             final int vpnsSize = mVpns.size();
             for (int i = 0; i < vpnsSize; i++) {
                 Vpn vpn = mVpns.valueAt(i);
                 vpn.onUserRemoved(userId);
+                NetworkCapabilities nc = vpn.updateCapabilities(defaultNetwork);
+                updateVpnCapabilities(vpn, nc);
             }
         }
     }
@@ -4165,6 +4186,7 @@
     private BroadcastReceiver mUserIntentReceiver = new BroadcastReceiver() {
         @Override
         public void onReceive(Context context, Intent intent) {
+            ensureRunningOnConnectivityServiceThread();
             final String action = intent.getAction();
             final int userId = intent.getIntExtra(Intent.EXTRA_USER_HANDLE, UserHandle.USER_NULL);
             if (userId == UserHandle.USER_NULL) return;
@@ -4650,6 +4672,19 @@
         return getNetworkForRequest(mDefaultRequest.requestId);
     }
 
+    @Nullable
+    private Network getNetwork(@Nullable NetworkAgentInfo nai) {
+        return nai != null ? nai.network : null;
+    }
+
+    private void ensureRunningOnConnectivityServiceThread() {
+        if (mHandler.getLooper().getThread() != Thread.currentThread()) {
+            throw new IllegalStateException(
+                    "Not running on ConnectivityService thread: "
+                            + Thread.currentThread().getName());
+        }
+    }
+
     private boolean isDefaultNetwork(NetworkAgentInfo nai) {
         return nai == getDefaultNetwork();
     }
@@ -5197,6 +5232,8 @@
         updateTcpBufferSizes(newNetwork);
         mDnsManager.setDefaultDnsSystemProperties(newNetwork.linkProperties.getDnsServers());
         notifyIfacesChangedForNetworkStats();
+        // Fix up the NetworkCapabilities of any VPNs that don't specify underlying networks.
+        updateAllVpnsCapabilities();
     }
 
     private void processListenRequests(NetworkAgentInfo nai, boolean capabilitiesChanged) {
@@ -5630,6 +5667,10 @@
             // doing.
             updateSignalStrengthThresholds(networkAgent, "CONNECT", null);
 
+            if (networkAgent.isVPN()) {
+                updateAllVpnsCapabilities();
+            }
+
             // Consider network even though it is not yet validated.
             final long now = SystemClock.elapsedRealtime();
             rematchNetworkAndRequests(networkAgent, ReapUnvalidatedNetworks.REAP, now);
@@ -5823,7 +5864,11 @@
             success = mVpns.get(user).setUnderlyingNetworks(networks);
         }
         if (success) {
-            mHandler.post(() -> notifyIfacesChangedForNetworkStats());
+            mHandler.post(() -> {
+                // Update VPN's capabilities based on updated underlying network set.
+                updateAllVpnsCapabilities();
+                notifyIfacesChangedForNetworkStats();
+            });
         }
         return success;
     }
diff --git a/services/core/java/com/android/server/connectivity/Vpn.java b/services/core/java/com/android/server/connectivity/Vpn.java
index 2a80f0e..56510b7e 100644
--- a/services/core/java/com/android/server/connectivity/Vpn.java
+++ b/services/core/java/com/android/server/connectivity/Vpn.java
@@ -273,7 +273,7 @@
         mNetworkCapabilities = new NetworkCapabilities();
         mNetworkCapabilities.addTransportType(NetworkCapabilities.TRANSPORT_VPN);
         mNetworkCapabilities.removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN);
-        updateCapabilities();
+        updateCapabilities(null /* defaultNetwork */);
 
         loadAlwaysOnPackage();
     }
@@ -300,18 +300,39 @@
         updateAlwaysOnNotification(detailedState);
     }
 
-    public void updateCapabilities() {
-        final Network[] underlyingNetworks = (mConfig != null) ? mConfig.underlyingNetworks : null;
-        updateCapabilities(mContext.getSystemService(ConnectivityManager.class), underlyingNetworks,
+    /**
+     * Updates {@link #mNetworkCapabilities} based on current underlying networks and returns a
+     * defensive copy.
+     *
+     * <p>Does not propagate updated capabilities to apps.
+     *
+     * @param defaultNetwork underlying network for VPNs following platform's default
+     */
+    public synchronized NetworkCapabilities updateCapabilities(
+            @Nullable Network defaultNetwork) {
+        if (mConfig == null) {
+            // VPN is not running.
+            return null;
+        }
+
+        Network[] underlyingNetworks = mConfig.underlyingNetworks;
+        if (underlyingNetworks == null && defaultNetwork != null) {
+            // null underlying networks means to track the default.
+            underlyingNetworks = new Network[] { defaultNetwork };
+        }
+
+        applyUnderlyingCapabilities(
+                mContext.getSystemService(ConnectivityManager.class),
+                underlyingNetworks,
                 mNetworkCapabilities);
 
-        if (mNetworkAgent != null) {
-            mNetworkAgent.sendNetworkCapabilities(mNetworkCapabilities);
-        }
+        return new NetworkCapabilities(mNetworkCapabilities);
     }
 
     @VisibleForTesting
-    public static void updateCapabilities(ConnectivityManager cm, Network[] underlyingNetworks,
+    public static void applyUnderlyingCapabilities(
+            ConnectivityManager cm,
+            Network[] underlyingNetworks,
             NetworkCapabilities caps) {
         int[] transportTypes = new int[] { NetworkCapabilities.TRANSPORT_VPN };
         int downKbps = NetworkCapabilities.LINK_BANDWIDTH_UNSPECIFIED;
@@ -323,6 +344,7 @@
         boolean hadUnderlyingNetworks = false;
         if (null != underlyingNetworks) {
             for (Network underlying : underlyingNetworks) {
+                // TODO(b/124469351): Get capabilities directly from ConnectivityService instead.
                 final NetworkCapabilities underlyingCaps = cm.getNetworkCapabilities(underlying);
                 if (underlyingCaps == null) continue;
                 hadUnderlyingNetworks = true;
@@ -993,9 +1015,8 @@
     }
 
     /**
-     * Establish a VPN network and return the file descriptor of the VPN
-     * interface. This methods returns {@code null} if the application is
-     * revoked or not prepared.
+     * Establish a VPN network and return the file descriptor of the VPN interface. This methods
+     * returns {@code null} if the application is revoked or not prepared.
      *
      * @param config The parameters to configure the network.
      * @return The file descriptor of the VPN interface.
@@ -1242,6 +1263,11 @@
         return ranges;
     }
 
+    /**
+     * Updates UID ranges for this VPN and also updates its internal capabilities.
+     *
+     * <p>Should be called on primary ConnectivityService thread.
+     */
     public void onUserAdded(int userHandle) {
         // If the user is restricted tie them to the parent user's VPN
         UserInfo user = UserManager.get(mContext).getUserInfo(userHandle);
@@ -1252,8 +1278,9 @@
                     try {
                         addUserToRanges(existingRanges, userHandle, mConfig.allowedApplications,
                                 mConfig.disallowedApplications);
+                        // ConnectivityService will call {@link #updateCapabilities} and apply
+                        // those for VPN network.
                         mNetworkCapabilities.setUids(existingRanges);
-                        updateCapabilities();
                     } catch (Exception e) {
                         Log.wtf(TAG, "Failed to add restricted user to owner", e);
                     }
@@ -1263,6 +1290,11 @@
         }
     }
 
+    /**
+     * Updates UID ranges for this VPN and also updates its capabilities.
+     *
+     * <p>Should be called on primary ConnectivityService thread.
+     */
     public void onUserRemoved(int userHandle) {
         // clean up if restricted
         UserInfo user = UserManager.get(mContext).getUserInfo(userHandle);
@@ -1274,8 +1306,9 @@
                         final List<UidRange> removedRanges =
                             uidRangesForUser(userHandle, existingRanges);
                         existingRanges.removeAll(removedRanges);
+                        // ConnectivityService will call {@link #updateCapabilities} and
+                        // apply those for VPN network.
                         mNetworkCapabilities.setUids(existingRanges);
-                        updateCapabilities();
                     } catch (Exception e) {
                         Log.wtf(TAG, "Failed to remove restricted user to owner", e);
                     }
@@ -1483,6 +1516,12 @@
         return success;
     }
 
+    /**
+     * Updates underlying network set.
+     *
+     * <p>Note: Does not updates capabilities. Call {@link #updateCapabilities} from
+     * ConnectivityService thread to get updated capabilities.
+     */
     public synchronized boolean setUnderlyingNetworks(Network[] networks) {
         if (!isCallerEstablishedOwnerLocked()) {
             return false;
@@ -1499,7 +1538,6 @@
                 }
             }
         }
-        updateCapabilities();
         return true;
     }
 
diff --git a/tests/net/java/com/android/server/ConnectivityServiceTest.java b/tests/net/java/com/android/server/ConnectivityServiceTest.java
index c2c627d..0dcb21a 100644
--- a/tests/net/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/net/java/com/android/server/ConnectivityServiceTest.java
@@ -20,6 +20,7 @@
 import static android.net.ConnectivityManager.PRIVATE_DNS_MODE_OFF;
 import static android.net.ConnectivityManager.PRIVATE_DNS_MODE_OPPORTUNISTIC;
 import static android.net.ConnectivityManager.PRIVATE_DNS_MODE_PROVIDER_HOSTNAME;
+import static android.net.ConnectivityManager.NETID_UNSET;
 import static android.net.ConnectivityManager.TYPE_ETHERNET;
 import static android.net.ConnectivityManager.TYPE_MOBILE;
 import static android.net.ConnectivityManager.TYPE_MOBILE_FOTA;
@@ -775,11 +776,14 @@
 
         public void setUids(Set<UidRange> uids) {
             mNetworkCapabilities.setUids(uids);
-            updateCapabilities();
+            updateCapabilities(null /* defaultNetwork */);
         }
 
         @Override
         public int getNetId() {
+            if (mMockNetworkAgent == null) {
+                return NETID_UNSET;
+            }
             return mMockNetworkAgent.getNetwork().netId;
         }
 
@@ -800,12 +804,13 @@
         }
 
         @Override
-        public void updateCapabilities() {
-            if (!mConnected) return;
-            super.updateCapabilities();
-            // Because super.updateCapabilities will update the capabilities of the agent but not
-            // the mock agent, the mock agent needs to know about them.
+        public NetworkCapabilities updateCapabilities(Network defaultNetwork) {
+            if (!mConnected) return null;
+            super.updateCapabilities(defaultNetwork);
+            // Because super.updateCapabilities will update the capabilities of the agent but
+            // not the mock agent, the mock agent needs to know about them.
             copyCapabilitiesToNetworkAgent();
+            return new NetworkCapabilities(mNetworkCapabilities);
         }
 
         private void copyCapabilitiesToNetworkAgent() {
@@ -4218,6 +4223,7 @@
         mMockVpn.setUids(ranges);
         vpnNetworkAgent.connect(false);
         mMockVpn.connect();
+        mMockVpn.setUnderlyingNetworks(new Network[0]);
 
         genericNetworkCallback.expectAvailableCallbacksUnvalidated(vpnNetworkAgent);
         genericNotVpnNetworkCallback.assertNoCallback();
@@ -4250,6 +4256,7 @@
 
         ranges.add(new UidRange(uid, uid));
         mMockVpn.setUids(ranges);
+        vpnNetworkAgent.setUids(ranges);
 
         genericNetworkCallback.expectAvailableCallbacksValidated(vpnNetworkAgent);
         genericNotVpnNetworkCallback.assertNoCallback();
@@ -4283,12 +4290,11 @@
     }
 
     @Test
-    public void testVpnWithAndWithoutInternet() {
+    public void testVpnWithoutInternet() {
         final int uid = Process.myUid();
 
         final TestNetworkCallback defaultCallback = new TestNetworkCallback();
         mCm.registerDefaultNetworkCallback(defaultCallback);
-        defaultCallback.assertNoCallback();
 
         mWiFiNetworkAgent = new MockNetworkAgent(TRANSPORT_WIFI);
         mWiFiNetworkAgent.connect(true);
@@ -4310,11 +4316,30 @@
         vpnNetworkAgent.disconnect();
         defaultCallback.assertNoCallback();
 
-        vpnNetworkAgent = new MockNetworkAgent(TRANSPORT_VPN);
+        mCm.unregisterNetworkCallback(defaultCallback);
+    }
+
+    @Test
+    public void testVpnWithInternet() {
+        final int uid = Process.myUid();
+
+        final TestNetworkCallback defaultCallback = new TestNetworkCallback();
+        mCm.registerDefaultNetworkCallback(defaultCallback);
+
+        mWiFiNetworkAgent = new MockNetworkAgent(TRANSPORT_WIFI);
+        mWiFiNetworkAgent.connect(true);
+
+        defaultCallback.expectAvailableThenValidatedCallbacks(mWiFiNetworkAgent);
+        assertEquals(defaultCallback.getLastAvailableNetwork(), mCm.getActiveNetwork());
+
+        MockNetworkAgent vpnNetworkAgent = new MockNetworkAgent(TRANSPORT_VPN);
+        final ArraySet<UidRange> ranges = new ArraySet<>();
+        ranges.add(new UidRange(uid, uid));
         mMockVpn.setNetworkAgent(vpnNetworkAgent);
         mMockVpn.setUids(ranges);
         vpnNetworkAgent.connect(true /* validated */, true /* hasInternet */);
         mMockVpn.connect();
+
         defaultCallback.expectAvailableThenValidatedCallbacks(vpnNetworkAgent);
         assertEquals(defaultCallback.getLastAvailableNetwork(), mCm.getActiveNetwork());
 
@@ -4322,14 +4347,6 @@
         defaultCallback.expectCallback(CallbackState.LOST, vpnNetworkAgent);
         defaultCallback.expectAvailableCallbacksValidated(mWiFiNetworkAgent);
 
-        vpnNetworkAgent = new MockNetworkAgent(TRANSPORT_VPN);
-        ranges.clear();
-        mMockVpn.setNetworkAgent(vpnNetworkAgent);
-        mMockVpn.setUids(ranges);
-        vpnNetworkAgent.connect(false /* validated */, true /* hasInternet */);
-        mMockVpn.connect();
-        defaultCallback.assertNoCallback();
-
         mCm.unregisterNetworkCallback(defaultCallback);
     }
 
@@ -4430,4 +4447,68 @@
 
         mMockVpn.disconnect();
     }
+
+    @Test
+    public void testNullUnderlyingNetworks() {
+        final int uid = Process.myUid();
+
+        final TestNetworkCallback vpnNetworkCallback = new TestNetworkCallback();
+        final NetworkRequest vpnNetworkRequest = new NetworkRequest.Builder()
+                .removeCapability(NET_CAPABILITY_NOT_VPN)
+                .addTransportType(TRANSPORT_VPN)
+                .build();
+        NetworkCapabilities nc;
+        mCm.registerNetworkCallback(vpnNetworkRequest, vpnNetworkCallback);
+        vpnNetworkCallback.assertNoCallback();
+
+        final MockNetworkAgent vpnNetworkAgent = new MockNetworkAgent(TRANSPORT_VPN);
+        final ArraySet<UidRange> ranges = new ArraySet<>();
+        ranges.add(new UidRange(uid, uid));
+        mMockVpn.setNetworkAgent(vpnNetworkAgent);
+        mMockVpn.connect();
+        mMockVpn.setUids(ranges);
+        vpnNetworkAgent.connect(true /* validated */, false /* hasInternet */);
+
+        vpnNetworkCallback.expectAvailableThenValidatedCallbacks(vpnNetworkAgent);
+        nc = mCm.getNetworkCapabilities(vpnNetworkAgent.getNetwork());
+        assertTrue(nc.hasTransport(TRANSPORT_VPN));
+        assertFalse(nc.hasTransport(TRANSPORT_CELLULAR));
+        assertFalse(nc.hasTransport(TRANSPORT_WIFI));
+        // By default, VPN is set to track default network (i.e. its underlying networks is null).
+        // In case of no default network, VPN is considered metered.
+        assertFalse(nc.hasCapability(NET_CAPABILITY_NOT_METERED));
+
+        // Connect to Cell; Cell is the default network.
+        mCellNetworkAgent = new MockNetworkAgent(TRANSPORT_CELLULAR);
+        mCellNetworkAgent.connect(true);
+
+        vpnNetworkCallback.expectCapabilitiesLike((caps) -> caps.hasTransport(TRANSPORT_VPN)
+                && caps.hasTransport(TRANSPORT_CELLULAR) && !caps.hasTransport(TRANSPORT_WIFI)
+                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED),
+                vpnNetworkAgent);
+
+        // Connect to WiFi; WiFi is the new default.
+        mWiFiNetworkAgent = new MockNetworkAgent(TRANSPORT_WIFI);
+        mWiFiNetworkAgent.addCapability(NET_CAPABILITY_NOT_METERED);
+        mWiFiNetworkAgent.connect(true);
+
+        vpnNetworkCallback.expectCapabilitiesLike((caps) -> caps.hasTransport(TRANSPORT_VPN)
+                && !caps.hasTransport(TRANSPORT_CELLULAR) && caps.hasTransport(TRANSPORT_WIFI)
+                && caps.hasCapability(NET_CAPABILITY_NOT_METERED),
+                vpnNetworkAgent);
+
+        // Disconnect Cell. The default network did not change, so there shouldn't be any changes in
+        // the capabilities.
+        mCellNetworkAgent.disconnect();
+
+        // Disconnect wifi too. Now we have no default network.
+        mWiFiNetworkAgent.disconnect();
+
+        vpnNetworkCallback.expectCapabilitiesLike((caps) -> caps.hasTransport(TRANSPORT_VPN)
+                && !caps.hasTransport(TRANSPORT_CELLULAR) && !caps.hasTransport(TRANSPORT_WIFI)
+                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED),
+                vpnNetworkAgent);
+
+        mMockVpn.disconnect();
+    }
 }
diff --git a/tests/net/java/com/android/server/connectivity/VpnTest.java b/tests/net/java/com/android/server/connectivity/VpnTest.java
index e377a47..a0a4ad1 100644
--- a/tests/net/java/com/android/server/connectivity/VpnTest.java
+++ b/tests/net/java/com/android/server/connectivity/VpnTest.java
@@ -457,7 +457,8 @@
 
         final NetworkCapabilities caps = new NetworkCapabilities();
 
-        Vpn.updateCapabilities(mConnectivityManager, new Network[] { }, caps);
+        Vpn.applyUnderlyingCapabilities(
+                mConnectivityManager, new Network[] {}, caps);
         assertTrue(caps.hasTransport(TRANSPORT_VPN));
         assertFalse(caps.hasTransport(TRANSPORT_CELLULAR));
         assertFalse(caps.hasTransport(TRANSPORT_WIFI));
@@ -467,7 +468,8 @@
         assertTrue(caps.hasCapability(NET_CAPABILITY_NOT_ROAMING));
         assertTrue(caps.hasCapability(NET_CAPABILITY_NOT_CONGESTED));
 
-        Vpn.updateCapabilities(mConnectivityManager, new Network[] { mobile }, caps);
+        Vpn.applyUnderlyingCapabilities(
+                mConnectivityManager, new Network[] {mobile}, caps);
         assertTrue(caps.hasTransport(TRANSPORT_VPN));
         assertTrue(caps.hasTransport(TRANSPORT_CELLULAR));
         assertFalse(caps.hasTransport(TRANSPORT_WIFI));
@@ -477,7 +479,8 @@
         assertFalse(caps.hasCapability(NET_CAPABILITY_NOT_ROAMING));
         assertTrue(caps.hasCapability(NET_CAPABILITY_NOT_CONGESTED));
 
-        Vpn.updateCapabilities(mConnectivityManager, new Network[] { wifi }, caps);
+        Vpn.applyUnderlyingCapabilities(
+                mConnectivityManager, new Network[] {wifi}, caps);
         assertTrue(caps.hasTransport(TRANSPORT_VPN));
         assertFalse(caps.hasTransport(TRANSPORT_CELLULAR));
         assertTrue(caps.hasTransport(TRANSPORT_WIFI));
@@ -487,7 +490,8 @@
         assertTrue(caps.hasCapability(NET_CAPABILITY_NOT_ROAMING));
         assertTrue(caps.hasCapability(NET_CAPABILITY_NOT_CONGESTED));
 
-        Vpn.updateCapabilities(mConnectivityManager, new Network[] { mobile, wifi }, caps);
+        Vpn.applyUnderlyingCapabilities(
+                mConnectivityManager, new Network[] {mobile, wifi}, caps);
         assertTrue(caps.hasTransport(TRANSPORT_VPN));
         assertTrue(caps.hasTransport(TRANSPORT_CELLULAR));
         assertTrue(caps.hasTransport(TRANSPORT_WIFI));