Add support for the sock_diag netlink interface.

Change-Id: Id5b1b3516d0a708bcfd69ae0e182dc39fe225934
diff --git a/net/test/netlink.py b/net/test/netlink.py
index 6b2c60d..514ad08 100644
--- a/net/test/netlink.py
+++ b/net/test/netlink.py
@@ -121,9 +121,10 @@
       # If it's an attribute we know about, try to decode it.
       nla_name, nla_data = self._Decode(command, msg, nla.nla_type, nla_data)
 
-      # We only support unique attributes for now.
-      if nla_name in attributes:
-        raise ValueError("Duplicate attribute %d" % nla_name)
+      # We only support unique attributes for now, except for INET_DIAG_NONE,
+      # which can appear more than once but doesn't seem to contain any data.
+      if nla_name in attributes and nla_name != "INET_DIAG_NONE":
+        raise ValueError("Duplicate attribute %s" % nla_name)
 
       attributes[nla_name] = nla_data
       self._Debug("      %s" % str((nla_name, nla_data)))
diff --git a/net/test/sock_diag.py b/net/test/sock_diag.py
new file mode 100755
index 0000000..8c70eb3
--- /dev/null
+++ b/net/test/sock_diag.py
@@ -0,0 +1,226 @@
+#!/usr/bin/python
+#
+# Copyright 2015 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Partial Python implementation of sock_diag functionality."""
+
+# pylint: disable=g-bad-todo
+
+import errno
+from socket import *  # pylint: disable=wildcard-import
+
+import cstruct
+import net_test
+import netlink
+
+### Base netlink constants. See include/uapi/linux/netlink.h.
+NETLINK_SOCK_DIAG = 4
+
+### sock_diag constants. See include/uapi/linux/sock_diag.h.
+# Message types.
+SOCK_DIAG_BY_FAMILY = 20
+SOCK_DESTROY = 21
+
+### inet_diag_constants. See include/uapi/linux/inet_diag.h
+# Message types.
+TCPDIAG_GETSOCK = 18
+
+# Extensions.
+INET_DIAG_NONE = 0
+INET_DIAG_MEMINFO = 1
+INET_DIAG_INFO = 2
+INET_DIAG_VEGASINFO = 3
+INET_DIAG_CONG = 4
+INET_DIAG_TOS = 5
+INET_DIAG_TCLASS = 6
+INET_DIAG_SKMEMINFO = 7
+INET_DIAG_SHUTDOWN = 8
+INET_DIAG_DCTCPINFO = 9
+
+# Data structure formats.
+# These aren't constants, they're classes. So, pylint: disable=invalid-name
+InetDiagSockId = cstruct.Struct(
+    "InetDiagSockId", "!HH16s16sI8s", "sport dport src dst iface cookie")
+InetDiagReqV2 = cstruct.Struct(
+    "InetDiagReqV2", "=BBBxIS", "family protocol ext states id",
+    [InetDiagSockId])
+InetDiagMsg = cstruct.Struct(
+    "InetDiagMsg", "=BBBBSLLLLL",
+    "family state timer retrans id expires rqueue wqueue uid inode",
+    [InetDiagSockId])
+InetDiagMeminfo = cstruct.Struct(
+    "InetDiagMeminfo", "=IIII", "rmem wmem fmem tmem")
+
+SkMeminfo = cstruct.Struct(
+    "SkMeminfo", "=IIIIIIII",
+    "rmem_alloc rcvbuf wmem_alloc sndbuf fwd_alloc wmem_queued optmem backlog")
+TcpInfo = cstruct.Struct(
+    "TcpInfo", "=BBBBBBBxIIIIIIIIIIIIIIIIIIIIIIII",
+    "state ca_state retransmits probes backoff options wscale "
+    "rto ato snd_mss rcv_mss "
+    "unacked sacked lost retrans fackets "
+    "last_data_sent last_ack_sent last_data_recv last_ack_recv "
+    "pmtu rcv_ssthresh rtt rttvar snd_ssthresh snd_cwnd advmss reordering "
+    "rcv_rtt rcv_space "
+    "total_retrans")  # As of linux 3.13, at least.
+
+# TCP states. See include/net/tcp_states.h.
+TCP_ESTABLISHED = 1
+TCP_SYN_SENT = 2
+TCP_SYN_RECV = 3
+TCP_FIN_WAIT1 = 4
+TCP_FIN_WAIT2 = 5
+TCP_TIME_WAIT = 6
+TCP_CLOSE = 7
+TCP_CLOSE_WAIT = 8
+TCP_LAST_ACK = 9
+TCP_LISTEN = 10
+TCP_CLOSING = 11
+TCP_NEW_SYN_RECV = 12
+
+
+class SockDiag(netlink.NetlinkSocket):
+
+  FAMILY = NETLINK_SOCK_DIAG
+  NL_DEBUG = []
+
+  def _Decode(self, command, msg, nla_type, nla_data):
+    """Decodes netlink attributes to Python types."""
+    if msg.family == AF_INET or msg.family == AF_INET6:
+      name = self._GetConstantName(__name__, nla_type, "INET_DIAG")
+    else:
+      # Don't know what this is. Leave it as an integer.
+      name = nla_type
+
+    if name in ["INET_DIAG_SHUTDOWN", "INET_DIAG_TOS", "INET_DIAG_TCLASS"]:
+      data = ord(nla_data)
+    elif name == "INET_DIAG_CONG":
+      data = nla_data.strip("\x00")
+    elif name == "INET_DIAG_MEMINFO":
+      data = InetDiagMeminfo(nla_data)
+    elif name == "INET_DIAG_INFO":
+      # TODO: Catch the exception and try something else if it's not TCP.
+      data = TcpInfo(nla_data)
+    elif name == "INET_DIAG_SKMEMINFO":
+      data = SkMeminfo(nla_data)
+    else:
+      data = nla_data
+
+    return name, data
+
+  def MaybeDebugCommand(self, command, data):
+    name = self._GetConstantName(__name__, command, "SOCK_")
+    if "ALL" not in self.NL_DEBUG and "SOCK" not in self.NL_DEBUG:
+      return
+    parsed = self._ParseNLMsg(data, InetDiagReqV2)
+    print "%s %s" % (name, str(parsed))
+
+  @staticmethod
+  def _EmptyInetDiagSockId():
+    return InetDiagSockId(("\x00" * len(InetDiagSockId)))
+
+  def DumpSockets(self, family, protocol, ext, states, sock_id):
+    """Dumps sockets matching the specified parameters."""
+    if sock_id is None:
+      sock_id = self._EmptyInetDiagSockId()
+
+    diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id))
+    return self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg)
+
+  def DumpAllInetSockets(self, protocol, sock_id=None, ext=0, states=0xffffffff):
+    # DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it
+    # results in ENOENT.
+    sockets = []
+    for family in [AF_INET, AF_INET6]:
+      sockets += self.DumpSockets(family, protocol, ext, states, None)
+    return sockets
+
+  @staticmethod
+  def GetRawAddress(family, addr):
+    """Fetches the source address from an InetDiagMsg."""
+    addrlen = {AF_INET:4, AF_INET6: 16}[family]
+    return inet_ntop(family, addr[:addrlen])
+
+  @staticmethod
+  def GetSourceAddress(diag_msg):
+    """Fetches the source address from an InetDiagMsg."""
+    return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.src)
+
+  @staticmethod
+  def GetDestinationAddress(diag_msg):
+    """Fetches the source address from an InetDiagMsg."""
+    return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.dst)
+
+  @staticmethod
+  def RawAddress(addr):
+    """Converts an IP address string to binary format."""
+    family = AF_INET6 if ":" in addr else AF_INET
+    return inet_pton(family, addr)
+
+  @staticmethod
+  def PaddedAddress(addr):
+    """Converts an IP address string to binary format for InetDiagSockId."""
+    padded = SockDiag.RawAddress(addr)
+    if len(padded) < 16:
+      padded += "\x00" * (16 - len(padded))
+    return padded
+
+  @staticmethod
+  def DiagReqFromSocket(s):
+    """Creates an InetDiagReqV2 that matches the specified socket."""
+    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
+    protocol = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_PROTOCOL)
+    iface = s.getsockopt(SOL_SOCKET, net_test.SO_BINDTODEVICE)
+    src, sport = s.getsockname()[:2]
+    try:
+      dst, dport = s.getpeername()[:2]
+    except error, e:
+      if e.errno == errno.ENOTCONN:
+        dport = 0
+        dst = "::" if family == AF_INET6 else "0.0.0.0"
+      else:
+        raise e
+    src = SockDiag.PaddedAddress(src)
+    dst = SockDiag.PaddedAddress(dst)
+    sock_id = InetDiagSockId((sport, dport, src, dst, iface, "\x00" * 8))
+    return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))
+
+  def GetSockDiagForFd(self, s):
+    """Gets an InetDiagMsg from the kernel for the specified socket."""
+    req = self.DiagReqFromSocket(s)
+    for diag_msg, attrs in self._Dump(SOCK_DIAG_BY_FAMILY, req, InetDiagMsg):
+      return diag_msg
+    raise ValueError("Dump of %s returned no sockets" % req)
+
+  def GetSockDiag(self, family, protocol, sock_id, ext=0, states=0xffffffff):
+    """Gets an InetDiagMsg from the kernel for the specified parameters."""
+    req = InetDiagReqV2((family, protocol, ext, states, sock_id))
+    self._SendNlRequest(SOCK_DIAG_BY_FAMILY, req.Pack(), netlink.NLM_F_REQUEST)
+    data = self._Recv()
+    return self._ParseNLMsg(data, InetDiagMsg)[0]
+
+
+if __name__ == "__main__":
+  n = SockDiag()
+  n.DEBUG = True
+  sock_id = n._EmptyInetDiagSockId()
+  sock_id.dport = 443
+  family = AF_INET6
+  protocol = IPPROTO_TCP
+  ext = 0
+  states = 0xffffffff
+  ext = 1 << (INET_DIAG_TOS - 1) | 1 << (INET_DIAG_TCLASS - 1)
+  diag_msgs = n.DumpSockets(family, protocol, ext, states, sock_id)
+  print diag_msgs
diff --git a/net/test/sock_diag_test.py b/net/test/sock_diag_test.py
new file mode 100755
index 0000000..2803bd2
--- /dev/null
+++ b/net/test/sock_diag_test.py
@@ -0,0 +1,112 @@
+#!/usr/bin/python
+#
+# Copyright 2015 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import errno
+import random
+from socket import *
+import time
+import unittest
+
+import csocket
+import cstruct
+import multinetwork_base
+import net_test
+import packets
+import sock_diag
+
+
+NUM_SOCKETS = 100
+
+ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << sock_diag.TCP_TIME_WAIT)
+
+
+class SockDiagTest(multinetwork_base.MultiNetworkBaseTest):
+
+  @staticmethod
+  def _CreateLotsOfSockets():
+    # Dict mapping (addr, sport, dport) tuples to socketpairs.
+    socketpairs = {}
+    for i in xrange(NUM_SOCKETS):
+      family, addr = random.choice([(AF_INET, "127.0.0.1"), (AF_INET6, "::1")])
+      socketpair = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
+      sport, dport = (socketpair[0].getsockname()[1],
+                      socketpair[1].getsockname()[1])
+      socketpairs[(addr, sport, dport)] = socketpair
+    return socketpairs
+
+  def setUp(self):
+    self.sock_diag = sock_diag.SockDiag()
+    self.socketpairs = self._CreateLotsOfSockets()
+
+  def tearDown(self):
+    [s.close() for socketpair in self.socketpairs.values() for s in socketpair]
+
+  def assertSockDiagMatchesSocket(self, s, diag_msg):
+    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
+    self.assertEqual(diag_msg.family, family)
+
+    # TODO: The kernel (at least 3.10) seems only to fill in the first 4 bytes
+    # of src and dst in the case of IPv4 addresses. This means we can't just do
+    # something like:
+    #  self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
+    # because the trailing bytes might not match.
+    # This seems like a bug because it might leaks kernel memory contents, but
+    # regardless, work around that here.
+    addrlen = {AF_INET: 4, AF_INET6: 16}[family]
+
+    src, sport = s.getsockname()[0:2]
+    self.assertEqual(diag_msg.id.sport, sport)
+    self.assertEqual(diag_msg.id.src[:addrlen],
+                     self.sock_diag.RawAddress(src))
+
+    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
+      dst, dport = s.getpeername()[0:2]
+      self.assertEqual(diag_msg.id.dst[:addrlen],
+                       self.sock_diag.RawAddress(dst))
+      self.assertEqual(diag_msg.id.dport, dport)
+    else:
+      assertRaisesErrno(errno.ENOTCONN, s.getpeername)
+
+  def testFindsAllMySockets(self):
+    sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
+                                                states=ALL_NON_TIME_WAIT)
+    self.assertGreaterEqual(len(sockets), NUM_SOCKETS)
+
+    # Find the cookies for all of our sockets.
+    cookies = {}
+    for diag_msg, attrs in sockets:
+      addr = self.sock_diag.GetSourceAddress(diag_msg)
+      sport = diag_msg.id.sport
+      dport = diag_msg.id.dport
+      if (addr, sport, dport) in self.socketpairs:
+        cookies[(addr, sport, dport)] = diag_msg.id.cookie
+      elif (addr, dport, sport) in self.socketpairs:
+        cookies[(addr, sport, dport)] = diag_msg.id.cookie
+
+    # Did we find all the cookies?
+    self.assertEquals(2 * NUM_SOCKETS, len(cookies))
+
+    socketpairs = self.socketpairs.values()
+    random.shuffle(socketpairs)
+    for socketpair in socketpairs:
+      for sock in socketpair:
+        self.assertSockDiagMatchesSocket(
+            sock,
+            self.sock_diag.GetSockDiagForFd(sock))
+
+
+if __name__ == "__main__":
+  unittest.main()