Add a generic netlink dump function and use it in SockDiag.
Test: bullhead builds, boots
Test: unit tests pass
Bug: 34873832
Change-Id: I8479a8c24277855b54a11a327618426678c8d360
diff --git a/server/SockDiag.cpp b/server/SockDiag.cpp
index 423ad3b..edb046c 100644
--- a/server/SockDiag.cpp
+++ b/server/SockDiag.cpp
@@ -203,39 +203,16 @@
return sendDumpRequest(proto, family, states, iov, ARRAY_SIZE(iov));
}
-int SockDiag::readDiagMsg(uint8_t proto, const SockDiag::DumpCallback& callback) {
- char buf[kBufferSize];
-
- ssize_t bytesread;
- do {
- bytesread = read(mSock, buf, sizeof(buf));
-
- if (bytesread < 0) {
- return -errno;
+int SockDiag::readDiagMsg(uint8_t proto, const SockDiag::DestroyFilter& shouldDestroy) {
+ NetlinkDumpCallback callback = [this, proto, shouldDestroy] (nlmsghdr *nlh) {
+ if (nlh == nullptr) return;
+ const inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
+ if (shouldDestroy(proto, msg)) {
+ sockDestroy(proto, msg);
}
+ };
- uint32_t len = bytesread;
- for (nlmsghdr *nlh = reinterpret_cast<nlmsghdr *>(buf);
- NLMSG_OK(nlh, len);
- nlh = NLMSG_NEXT(nlh, len)) {
- switch (nlh->nlmsg_type) {
- case NLMSG_DONE:
- callback(proto, NULL);
- return 0;
- case NLMSG_ERROR: {
- nlmsgerr *err = reinterpret_cast<nlmsgerr *>(NLMSG_DATA(nlh));
- return err->error;
- }
- default:
- inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
- if (callback(proto, msg)) {
- sockDestroy(proto, msg);
- }
- }
- }
- } while (bytesread > 0);
-
- return 0;
+ return processNetlinkDump(mSock, callback);
}
// Determines whether a socket is a loopback socket. Does not check socket state.
@@ -324,7 +301,7 @@
return mSocketsDestroyed;
}
-int SockDiag::destroyLiveSockets(DumpCallback destroyFilter, const char *what,
+int SockDiag::destroyLiveSockets(DestroyFilter destroyFilter, const char *what,
iovec *iov, int iovcnt) {
int proto = IPPROTO_TCP;
diff --git a/server/SockDiag.h b/server/SockDiag.h
index 5e545eb..7af8152 100644
--- a/server/SockDiag.h
+++ b/server/SockDiag.h
@@ -27,6 +27,7 @@
#include <functional>
#include <set>
+#include "NetlinkCommands.h"
#include "Permission.h"
#include "UidRanges.h"
@@ -43,7 +44,7 @@
// Callback function that is called once for every socket in the dump. A return value of true
// means destroy the socket.
- typedef std::function<bool(uint8_t proto, const inet_diag_msg *)> DumpCallback;
+ typedef std::function<bool(uint8_t proto, const inet_diag_msg *)> DestroyFilter;
struct DestroyRequest {
nlmsghdr nlh;
@@ -56,7 +57,8 @@
int sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states);
int sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr);
- int readDiagMsg(uint8_t proto, const DumpCallback& callback);
+ int readDiagMsg(uint8_t proto, const DestroyFilter& callback);
+
int sockDestroy(uint8_t proto, const inet_diag_msg *);
// Destroys all sockets on the given IPv4 or IPv6 address.
int destroySockets(const char *addrstr);
@@ -77,7 +79,7 @@
int mSocketsDestroyed;
int sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states, iovec *iov, int iovcnt);
int destroySockets(uint8_t proto, int family, const char *addrstr);
- int destroyLiveSockets(DumpCallback destroy, const char *what, iovec *iov, int iovcnt);
+ int destroyLiveSockets(DestroyFilter destroy, const char *what, iovec *iov, int iovcnt);
bool hasSocks() { return mSock != -1 && mWriteSock != -1; }
void closeSocks() { close(mSock); close(mWriteSock); mSock = mWriteSock = -1; }
static bool isLoopbackSocket(const inet_diag_msg *msg);