Use SOCK_DESTROY in netd.
Bug: 26976388
(cherry picked from commit f32fc598b01ba8d59873b0a1085716fd84678b54)
Change-Id: I2e4d0018fdcee7106fc083a522d81dba87a4db40
diff --git a/server/NetlinkHandler.cpp b/server/NetlinkHandler.cpp
index 97dc3e0..718fbdb 100644
--- a/server/NetlinkHandler.cpp
+++ b/server/NetlinkHandler.cpp
@@ -29,6 +29,7 @@
#include "NetlinkHandler.h"
#include "NetlinkManager.h"
#include "ResponseCode.h"
+#include "SockDiag.h"
static const char *kUpdated = "updated";
static const char *kRemoved = "removed";
@@ -78,14 +79,36 @@
const char *flags = evt->findParam("FLAGS");
const char *scope = evt->findParam("SCOPE");
if (action == NetlinkEvent::Action::kAddressRemoved && iface && address) {
- int resetMask = strchr(address, ':') ? RESET_IPV6_ADDRESSES : RESET_IPV4_ADDRESSES;
- resetMask |= RESET_IGNORE_INTERFACE_ADDRESS;
- if (int ret = ifc_reset_connections(iface, resetMask)) {
- ALOGE("ifc_reset_connections failed on iface %s for address %s (%s)", iface,
- address, strerror(ret));
+ // Note: if this interface was deleted, iface is "" and we don't notify.
+ SockDiag sd;
+ if (sd.open()) {
+ char addrstr[INET6_ADDRSTRLEN];
+ strncpy(addrstr, address, sizeof(addrstr));
+ char *slash = strchr(addrstr, '/');
+ if (slash) {
+ *slash = '\0';
+ }
+
+ int ret = sd.destroySockets(addrstr);
+ if (ret < 0) {
+ ALOGE("Error destroying sockets: %s", strerror(ret));
+ }
+ } else {
+ ALOGE("Error opening NETLINK_SOCK_DIAG socket: %s", strerror(errno));
+ }
+
+ // TODO: delete this once SOCK_DESTROY works everywhere.
+ if (iface[0]) {
+ int resetMask = strchr(address, ':') ?
+ RESET_IPV6_ADDRESSES : RESET_IPV4_ADDRESSES;
+ resetMask |= RESET_IGNORE_INTERFACE_ADDRESS;
+ if (int ret = ifc_reset_connections(iface, resetMask)) {
+ ALOGE("ifc_reset_connections failed on iface %s for address %s (%s)", iface,
+ address, strerror(ret));
+ }
}
}
- if (iface && flags && scope) {
+ if (iface && iface[0] && address && flags && scope) {
notifyAddressChanged(action, address, iface, flags, scope);
}
} else if (action == NetlinkEvent::Action::kRdnss) {
diff --git a/server/SockDiag.cpp b/server/SockDiag.cpp
index 2f1437c..b9f69cd 100644
--- a/server/SockDiag.cpp
+++ b/server/SockDiag.cpp
@@ -33,6 +33,8 @@
#include "NetdConstants.h"
#include "SockDiag.h"
+#include <chrono>
+
#ifndef SOCK_DESTROY
#define SOCK_DESTROY 21
#endif
@@ -208,6 +210,10 @@
}
int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
+ if (msg == nullptr) {
+ return 0;
+ }
+
DestroyRequest request = {
.nlh = {
.nlmsg_type = SOCK_DESTROY,
@@ -226,5 +232,47 @@
return -errno;
}
- return checkError(mWriteSock);
+ int ret = checkError(mWriteSock);
+ if (!ret) mSocketsDestroyed++;
+ return ret;
+}
+
+int SockDiag::destroySockets(uint8_t proto, int family, const char *addrstr) {
+ if (!hasSocks()) {
+ return -EBADFD;
+ }
+
+ if (int ret = sendDumpRequest(proto, family, addrstr)) {
+ return ret;
+ }
+
+ auto destroy = [this] (uint8_t proto, const inet_diag_msg *msg) {
+ return this->sockDestroy(proto, msg);
+ };
+
+ return readDiagMsg(proto, destroy);
+}
+
+int SockDiag::destroySockets(const char *addrstr) {
+ using ms = std::chrono::duration<float, std::ratio<1, 1000>>;
+
+ mSocketsDestroyed = 0;
+ const auto start = std::chrono::steady_clock::now();
+ if (!strchr(addrstr, ':')) {
+ if (int ret = destroySockets(IPPROTO_TCP, AF_INET, addrstr)) {
+ ALOGE("Failed to destroy IPv4 sockets on %s: %s", addrstr, strerror(-ret));
+ return ret;
+ }
+ }
+ if (int ret = destroySockets(IPPROTO_TCP, AF_INET6, addrstr)) {
+ ALOGE("Failed to destroy IPv6 sockets on %s: %s", addrstr, strerror(-ret));
+ return ret;
+ }
+ auto elapsed = std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start);
+
+ if (mSocketsDestroyed > 0) {
+ ALOGI("Destroyed %d sockets on %s in %.1f ms", mSocketsDestroyed, addrstr, elapsed.count());
+ }
+
+ return mSocketsDestroyed;
}
diff --git a/server/SockDiag.h b/server/SockDiag.h
index 3b6ca8b..56acbdb 100644
--- a/server/SockDiag.h
+++ b/server/SockDiag.h
@@ -5,6 +5,7 @@
#include <linux/inet_diag.h>
struct inet_diag_msg;
+class SockDiagTest;
class SockDiag {
@@ -17,17 +18,20 @@
inet_diag_req_v2 req;
} __attribute__((__packed__));
- SockDiag() : mSock(-1), mWriteSock(-1) {}
+ SockDiag() : mSock(-1), mWriteSock(-1), mSocketsDestroyed(0) {}
bool open();
virtual ~SockDiag() { closeSocks(); }
int sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr);
int readDiagMsg(uint8_t proto, DumpCallback callback);
int sockDestroy(uint8_t proto, const inet_diag_msg *);
+ int destroySockets(const char *addrstr);
private:
int mSock;
int mWriteSock;
+ int mSocketsDestroyed;
+ int destroySockets(uint8_t proto, int family, const char *addrstr);
bool hasSocks() { return mSock != -1 && mWriteSock != -1; }
void closeSocks() { close(mSock); close(mWriteSock); mSock = mWriteSock = -1; }
};
diff --git a/tests/sock_diag_test.cpp b/tests/sock_diag_test.cpp
index 67af978..8ee9908 100644
--- a/tests/sock_diag_test.cpp
+++ b/tests/sock_diag_test.cpp
@@ -213,18 +213,9 @@
SockDiag sd;
ASSERT_TRUE(sd.open()) << "Failed to open SOCK_DIAG socket";
- int ret = sd.sendDumpRequest(IPPROTO_TCP, AF_INET6, "::1");
- ASSERT_EQ(0, ret) << "Failed to send IPv6 dump request: " << strerror(-ret);
-
- auto closeMySockets = [&] (uint8_t proto, const inet_diag_msg *msg) {
- if (msg && msg->id.idiag_dport == htons(port)) {
- return sd.sockDestroy(proto, msg);
- }
- return 0;
- };
-
start = std::chrono::steady_clock::now();
- sd.readDiagMsg(IPPROTO_TCP, closeMySockets);
+ int ret = sd.destroySockets("::1");
+ EXPECT_LE(0, ret) << ": Failed to destroy sockets on ::1: " << strerror(-ret);
fprintf(stderr, " Destroying: %6.1f ms\n",
std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start).count());
@@ -235,17 +226,17 @@
err = errno;
EXPECT_EQ(-1, ret) << "Client socket " << i << " not closed";
if (ret == -1) {
- EXPECT_EQ(ECONNABORTED, errno)
+ // Since we're connected to ourselves, the error might be ECONNABORTED (if we destroyed
+ // the socket) or ECONNRESET (if the other end was destroyed and sent a RST).
+ EXPECT_TRUE(errno == ECONNABORTED || errno == ECONNRESET)
<< "Client socket: unexpected error: " << strerror(errno);
}
- // Check that the server sockets have been closed too (because closing the client sockets
- // sends RSTs).
ret = send(serversockets[i], "foo", sizeof("foo"), 0);
err = errno;
EXPECT_EQ(-1, ret) << "Server socket " << i << " not closed";
if (ret == -1) {
- EXPECT_EQ(ECONNRESET, errno)
+ EXPECT_TRUE(errno == ECONNABORTED || errno == ECONNRESET)
<< "Server socket: unexpected error: " << strerror(errno);
}
}