Merge "Let ConnectivityService control the socket closure"
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index 020c17a..4f0b34d 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -56,6 +56,7 @@
 import static android.net.NetworkPolicyManager.uidRulesToString;
 import static android.net.shared.NetworkMonitorUtils.isPrivateDnsValidationRequired;
 import static android.os.Process.INVALID_UID;
+import static android.os.Process.VPN_UID;
 import static android.system.OsConstants.IPPROTO_TCP;
 import static android.system.OsConstants.IPPROTO_UDP;
 
@@ -6675,6 +6676,58 @@
         return stableRanges;
     }
 
+    /**
+     * Asks netd to destroy the sockets of the uids in {@code ranges} if {@code nai} is a VPN
+     * that does not allow bypass. Sockets of the uids in {@code exemptUids} are spared.
+     *
+     * <p>Any exception from netd is logged rather than thrown, so callers can go on to update
+     * the uid routing rules regardless.
+     */
+    private void maybeCloseSockets(NetworkAgentInfo nai, UidRangeParcel[] ranges,
+            int[] exemptUids) {
+        if (nai.isVPN() && !nai.networkAgentConfig.allowBypass) {
+            try {
+                mNetd.socketDestroy(ranges, exemptUids);
+            } catch (Exception e) {
+                loge("Exception in socket destroy: ", e);
+            }
+        }
+    }
+
+    /**
+     * Adds {@code uidRanges} to, or removes them from, the network of {@code nai} in netd.
+     *
+     * <p>Sockets in the affected ranges are closed via {@link #maybeCloseSockets} both before
+     * and after the rule update (presumably so that sockets created while the rules are in
+     * flux are also torn down). {@code VPN_UID} and the network's owner uid are exempt. Netd
+     * failures are logged rather than thrown, and the second closure still runs.
+     *
+     * @param add {@code true} to add the ranges, {@code false} to remove them.
+     * @param nai the network whose uid ranges are updated.
+     * @param uidRanges the uid ranges to add or remove.
+     */
+    private void updateUidRanges(boolean add, NetworkAgentInfo nai, Set<UidRange> uidRanges) {
+        // TODO: Excluding VPN_UID is necessary in order not to kill the TCP connection used
+        // by PPTP. Fix this by making Vpn set the owner UID to VPN_UID instead of system when
+        // starting a legacy VPN, and remove VPN_UID here. (b/176542831)
+        final int[] exemptUids = {VPN_UID, nai.networkCapabilities.getOwnerUid()};
+        final UidRangeParcel[] ranges = toUidRangeStableParcels(uidRanges);
+
+        maybeCloseSockets(nai, ranges, exemptUids);
+        try {
+            if (add) {
+                mNetd.networkAddUidRanges(nai.network.netId, ranges);
+            } else {
+                mNetd.networkRemoveUidRanges(nai.network.netId, ranges);
+            }
+        } catch (Exception e) {
+            // Pass e as the Throwable argument so the stack trace is logged, as in
+            // maybeCloseSockets, instead of only its toString().
+            loge("Exception while " + (add ? "adding" : "removing") + " uid ranges " + uidRanges
+                    + " on netId " + nai.network.netId, e);
+        }
+        maybeCloseSockets(nai, ranges, exemptUids);
+    }
 
     private void updateUids(NetworkAgentInfo nai, NetworkCapabilities prevNc,
             NetworkCapabilities newNc) {
@@ -6694,12 +6728,21 @@
             // in both ranges are not subject to any VPN routing rules. Adding new range before
             // removing old range works because, unlike the filtering rules below, it's possible to
             // add duplicate UID routing rules.
+            // TODO: only close sockets of uids whose routing actually changes, i.e. the symmetric
+            // difference of prevRanges and newRanges. Imagine that we are removing uid 3 from a
+            // set containing 1-5. The uids whose routing changes are:
+            //   ([1-5] minus [1-2],[4-5]) plus ([1-2],[4-5] minus [1-5]) == [3]
+            // Then we can do:
+            //   maybeCloseSockets([3])
+            //   mNetd.networkAddUidRanges([1-2],[4-5])
+            //   mNetd.networkRemoveUidRanges([1-5])
+            //   maybeCloseSockets([3])
+            // This avoids closing sockets of uids 1-2 and 4-5 and cuts binder calls from 6 to 4.
             if (!newRanges.isEmpty()) {
-                mNetd.networkAddUidRanges(nai.network.netId, toUidRangeStableParcels(newRanges));
+                updateUidRanges(true, nai, newRanges);
             }
             if (!prevRanges.isEmpty()) {
-                mNetd.networkRemoveUidRanges(
-                        nai.network.netId, toUidRangeStableParcels(prevRanges));
+                updateUidRanges(false, nai, prevRanges);
             }
             final boolean wasFiltering = requiresVpnIsolation(nai, prevNc, nai.linkProperties);
             final boolean shouldFilter = requiresVpnIsolation(nai, newNc, nai.linkProperties);
diff --git a/tests/net/java/com/android/server/ConnectivityServiceTest.java b/tests/net/java/com/android/server/ConnectivityServiceTest.java
index 24b1343..1df7855 100644
--- a/tests/net/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/net/java/com/android/server/ConnectivityServiceTest.java
@@ -3363,6 +3363,7 @@
         assertEquals(null, mCm.getActiveNetwork());
 
         mMockVpn.establishForMyUid();
+        assertUidRangesUpdatedForMyUid(true);
         defaultNetworkCallback.expectAvailableThenValidatedCallbacks(mMockVpn);
         assertEquals(defaultNetworkCallback.getLastAvailableNetwork(), mCm.getActiveNetwork());
 
@@ -5053,6 +5054,7 @@
         lp.setInterfaceName(VPN_IFNAME);
 
         mMockVpn.establishForMyUid(lp);
+        assertUidRangesUpdatedForMyUid(true);
 
         final Network[] cellAndVpn = new Network[] {
                 mCellNetworkAgent.getNetwork(), mMockVpn.getNetwork()};
@@ -5638,6 +5640,7 @@
         // (and doing so is difficult without using reflection) but it's good to test that the code
         // behaves approximately correctly.
         mMockVpn.establishForMyUid(false, true, false);
+        assertUidRangesUpdatedForMyUid(true);
         final Network wifiNetwork = new Network(mNetIdManager.peekNextNetId());
         mService.setUnderlyingNetworksForVpn(new Network[]{wifiNetwork});
         callback.expectAvailableCallbacksUnvalidated(mMockVpn);
@@ -5795,6 +5798,7 @@
 
         mMockVpn.establishForMyUid(true /* validated */, false /* hasInternet */,
                 false /* isStrictMode */);
+        assertUidRangesUpdatedForMyUid(true);
 
         defaultCallback.assertNoCallback();
         assertEquals(defaultCallback.getLastAvailableNetwork(), mCm.getActiveNetwork());
@@ -5820,6 +5824,7 @@
 
         mMockVpn.establishForMyUid(true /* validated */, true /* hasInternet */,
                 false /* isStrictMode */);
+        assertUidRangesUpdatedForMyUid(true);
 
         defaultCallback.expectAvailableThenValidatedCallbacks(mMockVpn);
         assertEquals(defaultCallback.getLastAvailableNetwork(), mCm.getActiveNetwork());
@@ -5845,6 +5850,7 @@
         // Bring up a VPN that has the INTERNET capability, initially unvalidated.
         mMockVpn.establishForMyUid(false /* validated */, true /* hasInternet */,
                 false /* isStrictMode */);
+        assertUidRangesUpdatedForMyUid(true);
 
         // Even though the VPN is unvalidated, it becomes the default network for our app.
         callback.expectAvailableCallbacksUnvalidated(mMockVpn);
@@ -5896,6 +5902,7 @@
 
         mMockVpn.establishForMyUid(true /* validated */, false /* hasInternet */,
                 false /* isStrictMode */);
+        assertUidRangesUpdatedForMyUid(true);
 
         vpnNetworkCallback.expectAvailableCallbacks(mMockVpn.getNetwork(),
                 false /* suspended */, false /* validated */, false /* blocked */, TIMEOUT_MS);
@@ -5937,6 +5944,7 @@
 
         mMockVpn.establishForMyUid(true /* validated */, false /* hasInternet */,
                 false /* isStrictMode */);
+        assertUidRangesUpdatedForMyUid(true);
 
         vpnNetworkCallback.expectAvailableThenValidatedCallbacks(mMockVpn);
         nc = mCm.getNetworkCapabilities(mMockVpn.getNetwork());
@@ -6104,6 +6112,7 @@
 
         mMockVpn.establishForMyUid(true /* validated */, false /* hasInternet */,
                 false /* isStrictMode */);
+        assertUidRangesUpdatedForMyUid(true);
 
         vpnNetworkCallback.expectAvailableThenValidatedCallbacks(mMockVpn);
         nc = mCm.getNetworkCapabilities(mMockVpn.getNetwork());
@@ -6162,6 +6171,7 @@
 
         // Bring up a VPN
         mMockVpn.establishForMyUid();
+        assertUidRangesUpdatedForMyUid(true);
         callback.expectAvailableThenValidatedCallbacks(mMockVpn);
         callback.assertNoCallback();
 
@@ -6316,6 +6326,7 @@
 
         // Connect VPN network. By default it is using current default network (Cell).
         mMockVpn.establishForMyUid();
+        assertUidRangesUpdatedForMyUid(true);
 
         // Ensure VPN is now the active network.
         assertEquals(mMockVpn.getNetwork(), mCm.getActiveNetwork());
@@ -6368,6 +6379,7 @@
 
         // Connect VPN network.
         mMockVpn.establishForMyUid();
+        assertUidRangesUpdatedForMyUid(true);
 
         // Ensure VPN is now the active network.
         assertEquals(mMockVpn.getNetwork(), mCm.getActiveNetwork());
@@ -6742,6 +6754,7 @@
         assertNetworkInfo(TYPE_WIFI, DetailedState.BLOCKED);
 
         mMockVpn.establishForMyUid();
+        assertUidRangesUpdatedForMyUid(true);
         defaultCallback.expectAvailableThenValidatedCallbacks(mMockVpn);
         vpnUidCallback.assertNoCallback();  // vpnUidCallback has NOT_VPN capability.
         assertEquals(mMockVpn.getNetwork(), mCm.getActiveNetwork());
@@ -7399,6 +7412,7 @@
         LinkProperties testLinkProperties = new LinkProperties();
         testLinkProperties.setHttpProxy(testProxyInfo);
         mMockVpn.establishForMyUid(testLinkProperties);
+        assertUidRangesUpdatedForMyUid(true);
 
         // Test that the VPN network returns a proxy, and the WiFi does not.
         assertEquals(testProxyInfo, mService.getProxyForNetwork(mMockVpn.getNetwork()));
@@ -7436,6 +7450,7 @@
         // The uid range needs to cover the test app so the network is visible to it.
         final Set<UidRange> vpnRange = Collections.singleton(UidRange.createForUser(VPN_USER));
         mMockVpn.establish(lp, VPN_UID, vpnRange);
+        assertVpnUidRangesUpdated(true, vpnRange, VPN_UID);
 
         // A connected VPN should have interface rules set up. There are two expected invocations,
         // one during the VPN initial connection, one during the VPN LinkProperties update.
@@ -7463,6 +7478,7 @@
         // The uid range needs to cover the test app so the network is visible to it.
         final Set<UidRange> vpnRange = Collections.singleton(UidRange.createForUser(VPN_USER));
         mMockVpn.establish(lp, Process.SYSTEM_UID, vpnRange);
+        assertVpnUidRangesUpdated(true, vpnRange, Process.SYSTEM_UID);
 
         // Legacy VPN should not have interface rules set up
         verify(mMockNetd, never()).firewallAddUidInterfaceRules(any(), any());
@@ -7478,6 +7494,7 @@
         // The uid range needs to cover the test app so the network is visible to it.
         final Set<UidRange> vpnRange = Collections.singleton(UidRange.createForUser(VPN_USER));
         mMockVpn.establish(lp, Process.SYSTEM_UID, vpnRange);
+        assertVpnUidRangesUpdated(true, vpnRange, Process.SYSTEM_UID);
 
         // IPv6 unreachable route should not be misinterpreted as a default route
         verify(mMockNetd, never()).firewallAddUidInterfaceRules(any(), any());
@@ -7492,6 +7509,7 @@
         // The uid range needs to cover the test app so the network is visible to it.
         final Set<UidRange> vpnRange = Collections.singleton(UidRange.createForUser(VPN_USER));
         mMockVpn.establish(lp, VPN_UID, vpnRange);
+        assertVpnUidRangesUpdated(true, vpnRange, VPN_UID);
 
         // Connected VPN should have interface rules set up. There are two expected invocations,
         // one during VPN uid update, one during VPN LinkProperties update
@@ -7542,7 +7560,9 @@
         lp.addRoute(new RouteInfo(new IpPrefix(Inet6Address.ANY, 0), null));
         // The uid range needs to cover the test app so the network is visible to it.
         final UidRange vpnRange = UidRange.createForUser(VPN_USER);
-        mMockVpn.establish(lp, VPN_UID, Collections.singleton(vpnRange));
+        final Set<UidRange> vpnRanges = Collections.singleton(vpnRange);
+        mMockVpn.establish(lp, VPN_UID, vpnRanges);
+        assertVpnUidRangesUpdated(true, vpnRanges, VPN_UID);
 
         reset(mMockNetd);
         InOrder inOrder = inOrder(mMockNetd);
@@ -7693,6 +7713,7 @@
             throws Exception {
         final Set<UidRange> vpnRange = Collections.singleton(UidRange.createForUser(VPN_USER));
         mMockVpn.establish(new LinkProperties(), vpnOwnerUid, vpnRange);
+        assertVpnUidRangesUpdated(true, vpnRange, vpnOwnerUid);
         mMockVpn.setVpnType(vpnType);
 
         final VpnInfo vpnInfo = new VpnInfo();
@@ -7950,6 +7971,7 @@
                 Manifest.permission.ACCESS_FINE_LOCATION);
 
         mMockVpn.establishForMyUid();
+        assertUidRangesUpdatedForMyUid(true);
 
         // Wait for networks to connect and broadcasts to be sent before removing permissions.
         waitForIdle();
@@ -8229,4 +8251,75 @@
             assertTrue(isRequestIdInOrder);
         }
     }
+
+    /**
+     * Asserts that the uid ranges of {@link Process#myUid()} were added to or removed from the
+     * mock VPN, with sockets in those ranges closed before and after the update.
+     */
+    private void assertUidRangesUpdatedForMyUid(boolean add) throws Exception {
+        final int uid = Process.myUid();
+        assertVpnUidRangesUpdated(add, uidRangesForUid(uid), uid);
+    }
+
+    /**
+     * Asserts the netd calls made when {@code vpnRanges} are added to or removed from the mock
+     * VPN: socketDestroy, then networkAddUidRanges or networkRemoveUidRanges, then socketDestroy
+     * again, each socketDestroy exempting exactly {@link Process#VPN_UID} and {@code exemptUid}.
+     * Verification uses a fresh {@link InOrder}, i.e. starts from the first recorded call.
+     */
+    private void assertVpnUidRangesUpdated(boolean add, Set<UidRange> vpnRanges, int exemptUid)
+            throws Exception {
+        assertVpnUidRangesUpdated(inOrder(mMockNetd), add, vpnRanges, exemptUid);
+    }
+
+    /**
+     * Same as the overload above, but verifies against the supplied {@code inOrder}, continuing
+     * from wherever it left off so that consecutive updates can be checked in sequence.
+     */
+    private void assertVpnUidRangesUpdated(InOrder inOrder, boolean add, Set<UidRange> vpnRanges,
+            int exemptUid) throws Exception {
+        final ArgumentCaptor<int[]> exemptUidCaptor = ArgumentCaptor.forClass(int[].class);
+
+        inOrder.verify(mMockNetd, times(1)).socketDestroy(eq(toUidRangeStableParcels(vpnRanges)),
+                exemptUidCaptor.capture());
+        assertContainsExactly(exemptUidCaptor.getValue(), Process.VPN_UID, exemptUid);
+
+        if (add) {
+            inOrder.verify(mMockNetd, times(1)).networkAddUidRanges(eq(mMockVpn.getNetId()),
+                    eq(toUidRangeStableParcels(vpnRanges)));
+        } else {
+            inOrder.verify(mMockNetd, times(1)).networkRemoveUidRanges(eq(mMockVpn.getNetId()),
+                    eq(toUidRangeStableParcels(vpnRanges)));
+        }
+
+        inOrder.verify(mMockNetd, times(1)).socketDestroy(eq(toUidRangeStableParcels(vpnRanges)),
+                exemptUidCaptor.capture());
+        assertContainsExactly(exemptUidCaptor.getValue(), Process.VPN_UID, exemptUid);
+    }
+
+    @Test
+    public void testVpnUidRangesUpdate() throws Exception {
+        final LinkProperties lp = new LinkProperties();
+        lp.setInterfaceName("tun0");
+        lp.addRoute(new RouteInfo(new IpPrefix(Inet4Address.ANY, 0), null));
+        lp.addRoute(new RouteInfo(new IpPrefix(Inet6Address.ANY, 0), null));
+        final UidRange vpnRange = UidRange.createForUser(VPN_USER);
+        final Set<UidRange> vpnRanges = Collections.singleton(vpnRange);
+        mMockVpn.establish(lp, VPN_UID, vpnRanges);
+        assertVpnUidRangesUpdated(true, vpnRanges, VPN_UID);
+
+        reset(mMockNetd);
+        // Update to a new range that is the old range minus APP1_UID.
+        final Set<UidRange> newRanges = new HashSet<>(Arrays.asList(
+                new UidRange(vpnRange.start, APP1_UID - 1),
+                new UidRange(APP1_UID + 1, vpnRange.stop)));
+        mMockVpn.setUids(newRanges);
+        waitForIdle();
+
+        // updateUids adds the new ranges before removing the old ones. Verify both updates with
+        // a single InOrder so that this ordering is checked too.
+        final InOrder inOrder = inOrder(mMockNetd);
+        assertVpnUidRangesUpdated(inOrder, true, newRanges, VPN_UID);
+        assertVpnUidRangesUpdated(inOrder, false, vpnRanges, VPN_UID);
+    }
 }