Close sockets when changing network permissions.

Bug: 23113288
Change-Id: I8dcb02c79c81244e5b7288cb50770ac6a5867fcc
diff --git a/server/NetworkController.cpp b/server/NetworkController.cpp
index 014d926..aaf8b29 100644
--- a/server/NetworkController.cpp
+++ b/server/NetworkController.cpp
@@ -453,8 +453,6 @@
             return -EINVAL;
         }
 
-        // TODO: ioctl(SIOCKILLADDR, ...) to kill socets on the network that don't have permission.
-
         if (int ret = static_cast<PhysicalNetwork*>(network)->setPermission(permission)) {
             return ret;
         }
diff --git a/server/PhysicalNetwork.cpp b/server/PhysicalNetwork.cpp
index 495a93a..ee0e7c7 100644
--- a/server/PhysicalNetwork.cpp
+++ b/server/PhysicalNetwork.cpp
@@ -17,6 +17,7 @@
 #include "PhysicalNetwork.h"
 
 #include "RouteController.h"
+#include "SockDiag.h"
 
 #define LOG_TAG "Netd"
 #include "log/log.h"
@@ -65,10 +66,33 @@
     return mPermission;
 }
 
+int PhysicalNetwork::destroySocketsLackingPermission(Permission permission) {
+    if (permission == PERMISSION_NONE) return 0;
+
+    SockDiag sd;
+    if (!sd.open()) {
+       ALOGE("Error closing sockets for netId %d permission change", mNetId);
+       return -EBADFD;
+    }
+    if (int ret = sd.destroySocketsLackingPermission(mNetId, permission,
+                                                     true /* excludeLoopback */)) {
+        ALOGE("Failed to close sockets changing netId %d to permission %d: %s",
+              mNetId, permission, strerror(-ret));
+        return ret;
+    }
+    return 0;
+}
+
 int PhysicalNetwork::setPermission(Permission permission) {
     if (permission == mPermission) {
         return 0;
     }
+    if (mInterfaces.empty()) {
+        mPermission = permission;
+        return 0;
+    }
+
+    destroySocketsLackingPermission(permission);
     for (const std::string& interface : mInterfaces) {
         if (int ret = RouteController::modifyPhysicalNetworkPermission(mNetId, interface.c_str(),
                                                                        mPermission, permission)) {
@@ -87,6 +111,10 @@
             }
         }
     }
+    // Destroy sockets again in case any were opened after we called destroySocketsLackingPermission
+    // above and before we changed the permissions. These sockets won't be able to send any RST
+    // packets because they are now no longer routed, but at least the apps will get errors.
+    destroySocketsLackingPermission(permission);
     mPermission = permission;
     return 0;
 }
diff --git a/server/PhysicalNetwork.h b/server/PhysicalNetwork.h
index 2ef10df..cba3c6e 100644
--- a/server/PhysicalNetwork.h
+++ b/server/PhysicalNetwork.h
@@ -46,6 +46,7 @@
     Type getType() const override;
     int addInterface(const std::string& interface) override WARN_UNUSED_RESULT;
     int removeInterface(const std::string& interface) override WARN_UNUSED_RESULT;
+    int destroySocketsLackingPermission(Permission permission);
 
     Delegate* const mDelegate;
     Permission mPermission;
diff --git a/server/SockDiag.cpp b/server/SockDiag.cpp
index e91c6d1..1517b94 100644
--- a/server/SockDiag.cpp
+++ b/server/SockDiag.cpp
@@ -31,7 +31,9 @@
 #include <android-base/strings.h>
 #include <cutils/log.h>
 
+#include "Fwmark.h"
 #include "NetdConstants.h"
+#include "Permission.h"
 #include "SockDiag.h"
 
 #include <chrono>
@@ -40,6 +42,8 @@
 #define SOCK_DESTROY 21
 #endif
 
+#define INET_DIAG_BC_MARK_COND 10
+
 namespace {
 
 int checkError(int fd) {
@@ -186,9 +190,9 @@
     attrs.nla.nla_len = sizeof(attrs) + addrlen;
 
     iovec iov[] = {
-        { nullptr, 0 },
-        { &attrs, sizeof(attrs) },
-        { addr, addrlen },
+        { nullptr,           0 },
+        { &attrs,            sizeof(attrs) },
+        { addr,              addrlen },
     };
 
     uint32_t states = ~(1 << TCP_TIME_WAIT);
@@ -316,18 +320,19 @@
     return mSocketsDestroyed;
 }
 
-int SockDiag::destroyLiveSockets(DumpCallback destroyFilter) {
+int SockDiag::destroyLiveSockets(DumpCallback destroyFilter, const char *what,
+                                 iovec *iov, int iovcnt) {
     int proto = IPPROTO_TCP;
 
     for (const int family : {AF_INET, AF_INET6}) {
         const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
         uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
-        if (int ret = sendDumpRequest(proto, family, states)) {
-            ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
+        if (int ret = sendDumpRequest(proto, family, states, iov, iovcnt)) {
+            ALOGE("Failed to dump %s sockets for %s: %s", familyName, what, strerror(-ret));
             return ret;
         }
         if (int ret = readDiagMsg(proto, destroyFilter)) {
-            ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
+            ALOGE("Failed to destroy %s sockets for %s: %s", familyName, what, strerror(-ret));
             return ret;
         }
     }
@@ -377,7 +382,11 @@
                !(excludeLoopback && isLoopbackSocket(msg));
     };
 
-    if (int ret = destroyLiveSockets(shouldDestroy)) {
+    iovec iov[] = {
+        { nullptr, 0 },
+    };
+
+    if (int ret = destroyLiveSockets(shouldDestroy, "UID", iov, ARRAY_SIZE(iov))) {
         return ret;
     }
 
@@ -395,3 +404,95 @@
 
     return 0;
 }
+
+// Destroys all "live" (CONNECTED, SYN_SENT, SYN_RECV) TCP sockets on the specified netId where:
+// 1. The opening app no longer has permission to use this network, or:
+// 2. The opening app does have permission, but did not explicitly select this network.
+//
+// We destroy sockets without the explicit bit because we want to avoid the situation where a
+// privileged app uses its privileges without knowing it is doing so. For example, a privileged app
+// might have opened a socket on this network just because it was the default network at the
+// time. If we don't kill these sockets, those apps could continue to use them without realizing
+// that they are now sending and receiving traffic on a network that is now restricted.
+int SockDiag::destroySocketsLackingPermission(unsigned netId, Permission permission,
+                                              bool excludeLoopback) {
+    struct markmatch {
+        inet_diag_bc_op op;
+        // TODO: switch to inet_diag_markcond
+        __u32 mark;
+        __u32 mask;
+    } __attribute__((packed));
+    constexpr uint8_t matchlen = sizeof(markmatch);
+
+    Fwmark netIdMark, netIdMask;
+    netIdMark.netId = netId;
+    netIdMask.netId = 0xffff;
+
+    Fwmark controlMark;
+    controlMark.explicitlySelected = true;
+    controlMark.permission = permission;
+
+    // A SOCK_DIAG bytecode program that accepts the sockets we intend to destroy.
+    struct bytecode {
+        markmatch netIdMatch;
+        markmatch controlMatch;
+        inet_diag_bc_op controlJump;
+    } __attribute__((packed)) bytecode;
+
+    // The length of the INET_DIAG_BC_JMP instruction.
+    constexpr uint8_t jmplen = sizeof(inet_diag_bc_op);
+    // Jump exactly this far past the end of the program to reject.
+    constexpr uint8_t rejectoffset = sizeof(inet_diag_bc_op);
+    // Total length of the program.
+    constexpr uint8_t bytecodelen = sizeof(bytecode);
+
+    bytecode = (struct bytecode) {
+        // If netId matches, continue, otherwise, reject (i.e., leave socket alone).
+        { { INET_DIAG_BC_MARK_COND, matchlen, bytecodelen + rejectoffset },
+          netIdMark.intValue, netIdMask.intValue },
+
+        // If explicit and permission bits match, go to the JMP below which rejects the socket
+        // (i.e., we leave it alone). Otherwise, jump to the end of the program, which accepts the
+        // socket (so we destroy it).
+        { { INET_DIAG_BC_MARK_COND, matchlen, matchlen + jmplen },
+          controlMark.intValue, controlMark.intValue },
+
+        // This JMP unconditionally rejects the packet by jumping to the reject target. It is
+        // necessary to keep the kernel bytecode verifier happy. If we don't have a JMP the bytecode
+        // is invalid because the target of every no jump must always be reachable by yes jumps.
+        // Without this JMP, the accept target is not reachable by yes jumps and the program will
+        // be rejected by the validator.
+        { INET_DIAG_BC_JMP, jmplen, jmplen + rejectoffset },
+
+        // We have reached the end of the program. Accept the socket, and destroy it below.
+    };
+
+    struct nlattr nla = {
+        .nla_type = INET_DIAG_REQ_BYTECODE,
+        .nla_len = sizeof(struct nlattr) + bytecodelen,
+    };
+
+    iovec iov[] = {
+        { nullptr,   0 },
+        { &nla,      sizeof(nla) },
+        { &bytecode, bytecodelen },
+    };
+
+    mSocketsDestroyed = 0;
+    Stopwatch s;
+
+    auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
+        return msg != nullptr && !(excludeLoopback && isLoopbackSocket(msg));
+    };
+
+    if (int ret = destroyLiveSockets(shouldDestroy, "permission change", iov, ARRAY_SIZE(iov))) {
+        return ret;
+    }
+
+    if (mSocketsDestroyed > 0) {
+        ALOGI("Destroyed %d sockets for netId %d permission=%d in %.1f ms",
+              mSocketsDestroyed, netId, permission, s.timeTaken());
+    }
+
+    return 0;
+}
diff --git a/server/SockDiag.h b/server/SockDiag.h
index 5dc77c1..e561561 100644
--- a/server/SockDiag.h
+++ b/server/SockDiag.h
@@ -24,6 +24,7 @@
 #include <functional>
 #include <set>
 
+#include "Permission.h"
 #include "UidRanges.h"
 
 struct inet_diag_msg;
@@ -58,6 +59,10 @@
     // Destroys all "live" (CONNECTED, SYN_SENT, SYN_RECV) TCP sockets for the given UID ranges.
     int destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids,
                        bool excludeLoopback);
+    // Destroys all "live" (CONNECTED, SYN_SENT, SYN_RECV) TCP sockets that no longer have
+    // the permissions required by the specified network.
+    int destroySocketsLackingPermission(unsigned netId, Permission permission,
+                                        bool excludeLoopback);
 
   private:
     friend class SockDiagTest;
@@ -66,7 +71,7 @@
     int mSocketsDestroyed;
     int sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states, iovec *iov, int iovcnt);
     int destroySockets(uint8_t proto, int family, const char *addrstr);
-    int destroyLiveSockets(DumpCallback destroy);
+    int destroyLiveSockets(DumpCallback destroy, const char *what, iovec *iov, int iovcnt);
     bool hasSocks() { return mSock != -1 && mWriteSock != -1; }
     void closeSocks() { close(mSock); close(mWriteSock); mSock = mWriteSock = -1; }
     static bool isLoopbackSocket(const inet_diag_msg *msg);
diff --git a/server/SockDiagTest.cpp b/server/SockDiagTest.cpp
index f9353f3..70a199d 100644
--- a/server/SockDiagTest.cpp
+++ b/server/SockDiagTest.cpp
@@ -25,6 +25,7 @@
 
 #include <gtest/gtest.h>
 
+#include "Fwmark.h"
 #include "NetdConstants.h"
 #include "SockDiag.h"
 #include "UidRanges.h"
@@ -278,6 +279,7 @@
     UID_EXCLUDE_LOOPBACK,
     UIDRANGE,
     UIDRANGE_EXCLUDE_LOOPBACK,
+    PERMISSION,
 };
 
 const char *testTypeName(MicroBenchmarkTestType mode) {
@@ -288,10 +290,30 @@
         TO_STRING_TYPE(UID_EXCLUDE_LOOPBACK);
         TO_STRING_TYPE(UIDRANGE);
         TO_STRING_TYPE(UIDRANGE_EXCLUDE_LOOPBACK);
+        TO_STRING_TYPE(PERMISSION);
     }
 #undef TO_STRING_TYPE
 }
 
+static struct {
+    unsigned netId;
+    bool explicitlySelected;
+    Permission permission;
+} permissionTestcases[] = {
+    { 42, false, PERMISSION_NONE,    },
+    { 42, false, PERMISSION_NETWORK, },
+    { 42, false, PERMISSION_SYSTEM,  },
+    { 42, true,  PERMISSION_NONE,    },
+    { 42, true,  PERMISSION_NETWORK, },
+    { 42, true,  PERMISSION_SYSTEM,  },
+    { 43, false, PERMISSION_NONE,    },
+    { 43, false, PERMISSION_NETWORK, },
+    { 43, false, PERMISSION_SYSTEM,  },
+    { 43, true,  PERMISSION_NONE,    },
+    { 43, true,  PERMISSION_NETWORK, },
+    { 43, true,  PERMISSION_SYSTEM,  },
+};
+
 class SockDiagMicroBenchmarkTest : public ::testing::TestWithParam<MicroBenchmarkTestType> {
 
 public:
@@ -305,10 +327,15 @@
     constexpr static int MAX_SOCKETS = 500;
     constexpr static int ADDRESS_SOCKETS = 500;
     constexpr static int UID_SOCKETS = 50;
+    constexpr static int PERMISSION_SOCKETS = 16;
+
     constexpr static uid_t START_UID = 8000;  // START_UID + number of sockets must be <= 9999.
     constexpr static int CLOSE_UID = START_UID + UID_SOCKETS - 42;  // Close to the end
     static_assert(START_UID + MAX_SOCKETS < 9999, "Too many sockets");
 
+    constexpr static int TEST_NETID = 42;  // One of the OEM netIds.
+
+
     int howManySockets() {
         MicroBenchmarkTestType mode = GetParam();
         switch (mode) {
@@ -319,6 +346,30 @@
         case UIDRANGE:
         case UIDRANGE_EXCLUDE_LOOPBACK:
             return UID_SOCKETS;
+        case PERMISSION:
+            return ARRAY_SIZE(permissionTestcases);
+        }
+    }
+
+    int modifySocketForTest(int s, int i) {
+        MicroBenchmarkTestType mode = GetParam();
+        switch (mode) {
+        case UID:
+        case UID_EXCLUDE_LOOPBACK:
+        case UIDRANGE:
+        case UIDRANGE_EXCLUDE_LOOPBACK: {
+            uid_t uid = START_UID + i;
+            return fchown(s, uid, -1);
+        }
+        case PERMISSION: {
+            Fwmark fwmark;
+            fwmark.netId = permissionTestcases[i].netId;
+            fwmark.explicitlySelected = permissionTestcases[i].explicitlySelected;
+            fwmark.permission = permissionTestcases[i].permission;
+            return setsockopt(s, SOL_SOCKET, SO_MARK, &fwmark.intValue, sizeof(fwmark.intValue));
+        }
+        default:
+            return 0;
         }
     }
 
@@ -346,6 +397,11 @@
                 UidRanges uidRanges;
                 uidRanges.parseFrom(ARRAY_SIZE(uidRangeStrings), (char **) uidRangeStrings);
                 ret = mSd.destroySockets(uidRanges, skipUids, excludeLoopback);
+                break;
+            }
+            case PERMISSION: {
+                ret = mSd.destroySocketsLackingPermission(TEST_NETID, PERMISSION_NETWORK, false);
+                break;
             }
         }
         return ret;
@@ -373,6 +429,11 @@
             case UID_EXCLUDE_LOOPBACK:
             case UIDRANGE_EXCLUDE_LOOPBACK:
                 return false;
+            case PERMISSION:
+                if (permissionTestcases[i].netId != 42) return false;
+                if (permissionTestcases[i].explicitlySelected != 1) return true;
+                Permission permission = permissionTestcases[i].permission;
+                return permission != PERMISSION_NETWORK && permission != PERMISSION_SYSTEM;
         }
     }
 
@@ -424,11 +485,10 @@
     auto start = std::chrono::steady_clock::now();
     for (int i = 0; i < numSockets; i++) {
         int s = socket(AF_INET6, SOCK_STREAM, 0);
-        uid_t uid = START_UID + i;
-        ASSERT_EQ(0, fchown(s, uid, -1));
         clientlen = sizeof(client);
         ASSERT_EQ(0, connect(s, (sockaddr *) &server, sizeof(server)))
             << "Connecting socket " << i << " failed " << strerror(errno);
+        ASSERT_EQ(0, modifySocketForTest(s, i));
         serversockets[i] = accept(listensocket, (sockaddr *) &client, &clientlen);
         ASSERT_NE(-1, serversockets[i])
             << "Accepting socket " << i << " failed " << strerror(errno);
@@ -472,4 +532,5 @@
 
 INSTANTIATE_TEST_CASE_P(Address, SockDiagMicroBenchmarkTest,
                         testing::Values(ADDRESS, UID, UIDRANGE,
-                                        UID_EXCLUDE_LOOPBACK, UIDRANGE_EXCLUDE_LOOPBACK));
+                                        UID_EXCLUDE_LOOPBACK, UIDRANGE_EXCLUDE_LOOPBACK,
+                                        PERMISSION));