Add functional tests for tcp_nuke_addr.

1. Test for the hash table off-by-one error.
2. Test that sockets are closed and read() calls are interrupted
   with ETIMEDOUT.

Change-Id: Id8cdd50f1f5447734c230341f73d71fcfecdddd8
diff --git a/net/test/tcp_nuke_addr_test.py b/net/test/tcp_nuke_addr_test.py
index acac3fe..a0f44d3 100755
--- a/net/test/tcp_nuke_addr_test.py
+++ b/net/test/tcp_nuke_addr_test.py
@@ -15,7 +15,9 @@
 # limitations under the License.
 
 import contextlib
+import errno
 import fcntl
+import resource
 import os
 from socket import *  # pylint: disable=wildcard-import
 import struct
@@ -23,6 +25,7 @@
 import time
 import unittest
 
+import net_test
 
 IPV4_LOOPBACK_ADDR = '127.0.0.1'
 IPV6_LOOPBACK_ADDR = '::1'
@@ -32,6 +35,9 @@
 DEFAULT_TCP_PORT = 8001
 DEFAULT_BUFFER_SIZE = 20
 DEFAULT_TEST_MESSAGE = "TCP NUKE ADDR TEST"
+DEFAULT_TEST_RUNS = 100
+HASH_TEST_RUNS = 8000
+HASH_TEST_NOFILE = 16384
 
 
 @contextlib.contextmanager
@@ -92,30 +98,108 @@
   Raises:
     ValueError: If the address family is invalid for the ioctl.
   """
-  if addr_family == socket.AF_INET6:
+  if addr_family == AF_INET6:
     ifreq = struct.pack('BBBBBBBBBBBBBBBBIi',
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
                         128, 1)
-  elif addr_family == socket.AF_INET:
+  elif addr_family == AF_INET:
     raise NotImplementedError('Support for IPv4 not implemented yet.')
   else:
     raise ValueError('Address family %r not supported.' % addr_family)
-  datagram_socket = socket.socket(addr_family, socket.SOCK_DGRAM)
+  datagram_socket = socket(addr_family, SOCK_DGRAM)
   fcntl.ioctl(datagram_socket.fileno(), SIOCKILLADDR, ifreq)
   datagram_socket.close()
 
 
-class TcpNukeAddrTest(unittest.TestCase):
+class ExceptionalReadThread(threading.Thread):
 
-  def testIPv6KillAddr(self):
+  def __init__(self, sock):
+    self.sock = sock
+    self.exception = None
+    super(ExceptionalReadThread, self).__init__()
+    self.daemon = True
+
+  def run(self):
+    try:
+      read = self.sock.recv(4096)
+    except Exception, e:
+      self.exception = e
+
+
+def CreateSocketPair():
+  clientsock = socket(AF_INET6, SOCK_STREAM, 0)
+  listensock = socket(AF_INET6, SOCK_STREAM, 0)
+  listensock.bind((IPV6_LOOPBACK_ADDR, 0))
+  addr = listensock.getsockname()
+  listensock.listen(1)
+  clientsock.connect(addr)
+  acceptedsock, _ = listensock.accept()
+  return clientsock, acceptedsock
+
+
+class TcpNukeAddrTest(net_test.NetworkTest):
+
+  def testTimewaitSockets(self):
     """Tests that SIOCKILLADDR works as expected.
 
     Relevant kernel commits:
       https://www.codeaurora.org/cgit/quic/la/kernel/msm-3.18/commit/net/ipv4/tcp.c?h=aosp/android-3.10&id=1dcd3a1fa2fe78251cc91700eb1d384ab02e2dd6
     """
-    ExchangeMessage(socket.AF_INET6, IPV6_LOOPBACK_ADDR, DEFAULT_TCP_PORT)
-    KillAddrIoctl(socket.AF_INET6)
-    # Test passes if kernel does not crash.
+    for i in xrange(DEFAULT_TEST_RUNS):
+      ExchangeMessage(AF_INET6, IPV6_LOOPBACK_ADDR)
+      KillAddrIoctl(AF_INET6)
+      # Test passes if kernel does not crash.
+
+  def testClosesSockets(self):
+    """Tests that SIOCKILLADDR closes IPv6 sockets."""
+
+    threadpairs = []
+
+    for i in xrange(DEFAULT_TEST_RUNS):
+      clientsock, acceptedsock = CreateSocketPair()
+      clientthread = ExceptionalReadThread(clientsock)
+      clientthread.start()
+      serverthread = ExceptionalReadThread(acceptedsock)
+      serverthread.start()
+      threadpairs.append((clientthread, serverthread))
+
+    KillAddrIoctl(AF_INET6)
+
+    def CheckThreadException(thread):
+      thread.join(100)
+      self.assertFalse(thread.is_alive())
+      self.assertIsNotNone(thread.exception)
+      self.assertTrue(isinstance(thread.exception, IOError))
+      self.assertEquals(errno.ETIMEDOUT, thread.exception.errno)
+      self.assertRaisesErrno(errno.ENOTCONN, thread.sock.getpeername)
+      self.assertRaisesErrno(errno.EISCONN, thread.sock.connect, ("::1", 53))
+      self.assertRaisesErrno(errno.EPIPE, thread.sock.send, "foo")
+
+    for clientthread, serverthread in threadpairs:
+      CheckThreadException(clientthread)
+      CheckThreadException(serverthread)
+
+
+class TcpNukeAddrHashTest(net_test.NetworkTest):
+
+  def setUp(self):
+    self.nofile = resource.getrlimit(resource.RLIMIT_NOFILE)
+    resource.setrlimit(resource.RLIMIT_NOFILE, (HASH_TEST_NOFILE,
+                                                HASH_TEST_NOFILE))
+
+  def tearDown(self):
+    resource.setrlimit(resource.RLIMIT_NOFILE, self.nofile)
+
+  def testClosesAllSockets(self):
+    socketpairs = []
+    for i in xrange(HASH_TEST_RUNS):
+      socketpairs.append(CreateSocketPair())
+
+    KillAddrIoctl(AF_INET6)
+
+    for socketpair in socketpairs:
+      for sock in socketpair:
+        self.assertRaisesErrno(errno.ENOTCONN, sock.getpeername)
 
 
 if __name__ == "__main__":