Add a generic netlink dump function and use it in SockDiag.

Test: bullhead builds, boots
Test: unit tests pass
Bug: 34873832
Change-Id: I8479a8c24277855b54a11a327618426678c8d360
diff --git a/server/SockDiag.cpp b/server/SockDiag.cpp
index 423ad3b..edb046c 100644
--- a/server/SockDiag.cpp
+++ b/server/SockDiag.cpp
@@ -203,39 +203,16 @@
     return sendDumpRequest(proto, family, states, iov, ARRAY_SIZE(iov));
 }
 
-int SockDiag::readDiagMsg(uint8_t proto, const SockDiag::DumpCallback& callback) {
-    char buf[kBufferSize];
-
-    ssize_t bytesread;
-    do {
-        bytesread = read(mSock, buf, sizeof(buf));
-
-        if (bytesread < 0) {
-            return -errno;
+int SockDiag::readDiagMsg(uint8_t proto, const SockDiag::DestroyFilter& shouldDestroy) {
+    NetlinkDumpCallback callback = [this, proto, shouldDestroy] (nlmsghdr *nlh) {
+        if (nlh == nullptr) return;
+        const inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
+        if (shouldDestroy(proto, msg)) {
+            sockDestroy(proto, msg);
         }
+    };
 
-        uint32_t len = bytesread;
-        for (nlmsghdr *nlh = reinterpret_cast<nlmsghdr *>(buf);
-             NLMSG_OK(nlh, len);
-             nlh = NLMSG_NEXT(nlh, len)) {
-            switch (nlh->nlmsg_type) {
-              case NLMSG_DONE:
-                callback(proto, NULL);
-                return 0;
-              case NLMSG_ERROR: {
-                nlmsgerr *err = reinterpret_cast<nlmsgerr *>(NLMSG_DATA(nlh));
-                return err->error;
-              }
-              default:
-                inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
-                if (callback(proto, msg)) {
-                    sockDestroy(proto, msg);
-                }
-            }
-        }
-    } while (bytesread > 0);
-
-    return 0;
+    return processNetlinkDump(mSock, callback);
 }
 
 // Determines whether a socket is a loopback socket. Does not check socket state.
@@ -324,7 +301,7 @@
     return mSocketsDestroyed;
 }
 
-int SockDiag::destroyLiveSockets(DumpCallback destroyFilter, const char *what,
+int SockDiag::destroyLiveSockets(DestroyFilter destroyFilter, const char *what,
                                  iovec *iov, int iovcnt) {
     int proto = IPPROTO_TCP;
 
diff --git a/server/SockDiag.h b/server/SockDiag.h
index 5e545eb..7af8152 100644
--- a/server/SockDiag.h
+++ b/server/SockDiag.h
@@ -27,6 +27,7 @@
 #include <functional>
 #include <set>
 
+#include "NetlinkCommands.h"
 #include "Permission.h"
 #include "UidRanges.h"
 
@@ -43,7 +44,7 @@
 
     // Callback function that is called once for every socket in the dump. A return value of true
     // means destroy the socket.
-    typedef std::function<bool(uint8_t proto, const inet_diag_msg *)> DumpCallback;
+    typedef std::function<bool(uint8_t proto, const inet_diag_msg *)> DestroyFilter;
 
     struct DestroyRequest {
         nlmsghdr nlh;
@@ -56,7 +57,8 @@
 
     int sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states);
     int sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr);
-    int readDiagMsg(uint8_t proto, const DumpCallback& callback);
+    int readDiagMsg(uint8_t proto, const DestroyFilter& callback);
+
     int sockDestroy(uint8_t proto, const inet_diag_msg *);
     // Destroys all sockets on the given IPv4 or IPv6 address.
     int destroySockets(const char *addrstr);
@@ -77,7 +79,7 @@
     int mSocketsDestroyed;
     int sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states, iovec *iov, int iovcnt);
     int destroySockets(uint8_t proto, int family, const char *addrstr);
-    int destroyLiveSockets(DumpCallback destroy, const char *what, iovec *iov, int iovcnt);
+    int destroyLiveSockets(DestroyFilter destroy, const char *what, iovec *iov, int iovcnt);
     bool hasSocks() { return mSock != -1 && mWriteSock != -1; }
     void closeSocks() { close(mSock); close(mWriteSock); mSock = mWriteSock = -1; }
     static bool isLoopbackSocket(const inet_diag_msg *msg);