Use SOCK_DESTROY in netd.

Bug: 26976388

(cherry picked from commit f32fc598b01ba8d59873b0a1085716fd84678b54)

Change-Id: I2e4d0018fdcee7106fc083a522d81dba87a4db40
diff --git a/server/NetlinkHandler.cpp b/server/NetlinkHandler.cpp
index 97dc3e0..718fbdb 100644
--- a/server/NetlinkHandler.cpp
+++ b/server/NetlinkHandler.cpp
@@ -29,6 +29,7 @@
 #include "NetlinkHandler.h"
 #include "NetlinkManager.h"
 #include "ResponseCode.h"
+#include "SockDiag.h"
 
 static const char *kUpdated = "updated";
 static const char *kRemoved = "removed";
@@ -78,14 +79,36 @@
             const char *flags = evt->findParam("FLAGS");
             const char *scope = evt->findParam("SCOPE");
             if (action == NetlinkEvent::Action::kAddressRemoved && iface && address) {
-                int resetMask = strchr(address, ':') ? RESET_IPV6_ADDRESSES : RESET_IPV4_ADDRESSES;
-                resetMask |= RESET_IGNORE_INTERFACE_ADDRESS;
-                if (int ret = ifc_reset_connections(iface, resetMask)) {
-                    ALOGE("ifc_reset_connections failed on iface %s for address %s (%s)", iface,
-                          address, strerror(ret));
+                // Note: if this interface was deleted, iface is "" and we don't notify.
+                SockDiag sd;
+                if (sd.open()) {
+                    char addrstr[INET6_ADDRSTRLEN];
+                    strncpy(addrstr, address, sizeof(addrstr));
+                    char *slash = strchr(addrstr, '/');
+                    if (slash) {
+                        *slash = '\0';
+                    }
+
+                    int ret = sd.destroySockets(addrstr);
+                    if (ret < 0) {
+                        ALOGE("Error destroying sockets: %s", strerror(ret));
+                    }
+                } else {
+                    ALOGE("Error opening NETLINK_SOCK_DIAG socket: %s", strerror(errno));
+                }
+
+                // TODO: delete this once SOCK_DESTROY works everywhere.
+                if (iface[0]) {
+                    int resetMask = strchr(address, ':') ?
+                            RESET_IPV6_ADDRESSES : RESET_IPV4_ADDRESSES;
+                    resetMask |= RESET_IGNORE_INTERFACE_ADDRESS;
+                    if (int ret = ifc_reset_connections(iface, resetMask)) {
+                        ALOGE("ifc_reset_connections failed on iface %s for address %s (%s)", iface,
+                              address, strerror(ret));
+                    }
                 }
             }
-            if (iface && flags && scope) {
+            if (iface && iface[0] && address && flags && scope) {
                 notifyAddressChanged(action, address, iface, flags, scope);
             }
         } else if (action == NetlinkEvent::Action::kRdnss) {
diff --git a/server/SockDiag.cpp b/server/SockDiag.cpp
index 2f1437c..b9f69cd 100644
--- a/server/SockDiag.cpp
+++ b/server/SockDiag.cpp
@@ -33,6 +33,8 @@
 #include "NetdConstants.h"
 #include "SockDiag.h"
 
+#include <chrono>
+
 #ifndef SOCK_DESTROY
 #define SOCK_DESTROY 21
 #endif
@@ -208,6 +210,10 @@
 }
 
 int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
+    if (msg == nullptr) {
+       return 0;
+    }
+
     DestroyRequest request = {
         .nlh = {
             .nlmsg_type = SOCK_DESTROY,
@@ -226,5 +232,47 @@
         return -errno;
     }
 
-    return checkError(mWriteSock);
+    int ret = checkError(mWriteSock);
+    if (!ret) mSocketsDestroyed++;
+    return ret;
+}
+
+int SockDiag::destroySockets(uint8_t proto, int family, const char *addrstr) {
+    if (!hasSocks()) {
+        return -EBADFD;
+    }
+
+    if (int ret = sendDumpRequest(proto, family, addrstr)) {
+        return ret;
+    }
+
+    auto destroy = [this] (uint8_t proto, const inet_diag_msg *msg) {
+        return this->sockDestroy(proto, msg);
+    };
+
+    return readDiagMsg(proto, destroy);
+}
+
+int SockDiag::destroySockets(const char *addrstr) {
+    using ms = std::chrono::duration<float, std::ratio<1, 1000>>;
+
+    mSocketsDestroyed = 0;
+    const auto start = std::chrono::steady_clock::now();
+    if (!strchr(addrstr, ':')) {
+        if (int ret = destroySockets(IPPROTO_TCP, AF_INET, addrstr)) {
+            ALOGE("Failed to destroy IPv4 sockets on %s: %s", addrstr, strerror(-ret));
+            return ret;
+        }
+    }
+    if (int ret = destroySockets(IPPROTO_TCP, AF_INET6, addrstr)) {
+        ALOGE("Failed to destroy IPv6 sockets on %s: %s", addrstr, strerror(-ret));
+        return ret;
+    }
+    auto elapsed = std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start);
+
+    if (mSocketsDestroyed > 0) {
+        ALOGI("Destroyed %d sockets on %s in %.1f ms", mSocketsDestroyed, addrstr, elapsed.count());
+    }
+
+    return mSocketsDestroyed;
 }
diff --git a/server/SockDiag.h b/server/SockDiag.h
index 3b6ca8b..56acbdb 100644
--- a/server/SockDiag.h
+++ b/server/SockDiag.h
@@ -5,6 +5,7 @@
 #include <linux/inet_diag.h>
 
 struct inet_diag_msg;
+class SockDiagTest;
 
 class SockDiag {
 
@@ -17,17 +18,20 @@
         inet_diag_req_v2 req;
     } __attribute__((__packed__));
 
-    SockDiag() : mSock(-1), mWriteSock(-1) {}
+    SockDiag() : mSock(-1), mWriteSock(-1), mSocketsDestroyed(0) {}
     bool open();
     virtual ~SockDiag() { closeSocks(); }
 
     int sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr);
     int readDiagMsg(uint8_t proto, DumpCallback callback);
     int sockDestroy(uint8_t proto, const inet_diag_msg *);
+    int destroySockets(const char *addrstr);
 
   private:
     int mSock;
     int mWriteSock;
+    int mSocketsDestroyed;
+    int destroySockets(uint8_t proto, int family, const char *addrstr);
     bool hasSocks() { return mSock != -1 && mWriteSock != -1; }
     void closeSocks() { close(mSock); close(mWriteSock); mSock = mWriteSock = -1; }
 };
diff --git a/tests/sock_diag_test.cpp b/tests/sock_diag_test.cpp
index 67af978..8ee9908 100644
--- a/tests/sock_diag_test.cpp
+++ b/tests/sock_diag_test.cpp
@@ -213,18 +213,9 @@
     SockDiag sd;
     ASSERT_TRUE(sd.open()) << "Failed to open SOCK_DIAG socket";
 
-    int ret = sd.sendDumpRequest(IPPROTO_TCP, AF_INET6, "::1");
-    ASSERT_EQ(0, ret) << "Failed to send IPv6 dump request: " << strerror(-ret);
-
-    auto closeMySockets = [&] (uint8_t proto, const inet_diag_msg *msg) {
-        if (msg && msg->id.idiag_dport == htons(port)) {
-            return sd.sockDestroy(proto, msg);
-        }
-        return 0;
-    };
-
     start = std::chrono::steady_clock::now();
-    sd.readDiagMsg(IPPROTO_TCP, closeMySockets);
+    int ret = sd.destroySockets("::1");
+    EXPECT_LE(0, ret) << ": Failed to destroy sockets on ::1: " << strerror(-ret);
     fprintf(stderr, "  Destroying: %6.1f ms\n",
             std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start).count());
 
@@ -235,17 +226,17 @@
         err = errno;
         EXPECT_EQ(-1, ret) << "Client socket " << i << " not closed";
         if (ret == -1) {
-            EXPECT_EQ(ECONNABORTED, errno)
+            // Since we're connected to ourselves, the error might be ECONNABORTED (if we destroyed
+            // the socket) or ECONNRESET (if the other end was destroyed and sent a RST).
+            EXPECT_TRUE(errno == ECONNABORTED || errno == ECONNRESET)
                 << "Client socket: unexpected error: " << strerror(errno);
         }
 
-        // Check that the server sockets have been closed too (because closing the client sockets
-        // sends RSTs).
         ret = send(serversockets[i], "foo", sizeof("foo"), 0);
         err = errno;
         EXPECT_EQ(-1, ret) << "Server socket " << i << " not closed";
         if (ret == -1) {
-            EXPECT_EQ(ECONNRESET, errno)
+            EXPECT_TRUE(errno == ECONNABORTED || errno == ECONNRESET)
                 << "Server socket: unexpected error: " << strerror(errno);
         }
     }