Merge "Fix for simpleperf report on 'fugu'."
diff --git a/f2fs_utils/Android.mk b/f2fs_utils/Android.mk
index 0f39116..2bc190f 100644
--- a/f2fs_utils/Android.mk
+++ b/f2fs_utils/Android.mk
@@ -17,6 +17,7 @@
 LOCAL_SRC_FILES := f2fs_ioutils.c
 LOCAL_C_INCLUDES := external/f2fs-tools/include external/f2fs-tools/mkfs
 LOCAL_STATIC_LIBRARIES := \
+    libselinux \
     libsparse_host \
     libext2_uuid-host \
     libz
diff --git a/tests/net_test/iproute.py b/tests/net_test/iproute.py
index 7c9eec0..e0f2837 100644
--- a/tests/net_test/iproute.py
+++ b/tests/net_test/iproute.py
@@ -454,10 +454,7 @@
       # implement parsing dump results.
       raise NotImplementedError("IPv4 RTM_GETADDR not implemented.")
     self._Address(6, RTM_GETADDR, address, 0, 0, RT_SCOPE_UNIVERSE, ifindex)
-    data = self._Recv()
-    if NLMsgHdr(data).type == NLMSG_ERROR:
-      self._ParseAck(data)
-    return self._ParseNLMsg(data, IfAddrMsg)[0]
+    return self._GetMsg(IfAddrMsg)  # NLMSG_ERROR replies go to _ParseAck.
 
   def _Route(self, version, command, table, dest, prefixlen, nexthop, dev,
              mark, uid):
diff --git a/tests/net_test/net_test.py b/tests/net_test/net_test.py
index f108aa8..d7ea013 100755
--- a/tests/net_test/net_test.py
+++ b/tests/net_test/net_test.py
@@ -56,6 +56,8 @@
 IPV6_FL_S_EXCL = 1
 IPV6_FL_S_ANY = 255
 
+IFNAMSIZ = 16  # Width of the interface-name field packed into ifreq buffers.
+
 IPV4_PING = "\x08\x00\x00\x00\x0a\xce\x00\x03"
 IPV6_PING = "\x80\x00\x00\x00\x0a\xce\x00\x03"
 
@@ -171,9 +173,9 @@
 
 def GetInterfaceIndex(ifname):
   s = IPv4PingSocket()
-  ifr = struct.pack("16si", ifname, 0)
+  ifr = struct.pack("%dsi" % IFNAMSIZ, ifname, 0)
   ifr = fcntl.ioctl(s, scapy.SIOCGIFINDEX, ifr)
-  return struct.unpack("16si", ifr)[1]
+  return struct.unpack("%dsi" % IFNAMSIZ, ifr)[1]
 
 
 def SetInterfaceHWAddr(ifname, hwaddr):
@@ -182,20 +184,20 @@
   hwaddr = hwaddr.decode("hex")
   if len(hwaddr) != 6:
     raise ValueError("Unknown hardware address length %d" % len(hwaddr))
-  ifr = struct.pack("16sH6s", ifname, scapy.ARPHDR_ETHER, hwaddr)
+  ifr = struct.pack("%dsH6s" % IFNAMSIZ, ifname, scapy.ARPHDR_ETHER, hwaddr)
   fcntl.ioctl(s, SIOCSIFHWADDR, ifr)
 
 
 def SetInterfaceState(ifname, up):
   s = IPv4PingSocket()
-  ifr = struct.pack("16sH", ifname, 0)
+  ifr = struct.pack("%dsH" % IFNAMSIZ, ifname, 0)
   ifr = fcntl.ioctl(s, scapy.SIOCGIFFLAGS, ifr)
-  _, flags = struct.unpack("16sH", ifr)
+  _, flags = struct.unpack("%dsH" % IFNAMSIZ, ifr)
   if up:
     flags |= scapy.IFF_UP
   else:
     flags &= ~scapy.IFF_UP
-  ifr = struct.pack("16sH", ifname, flags)
+  ifr = struct.pack("%dsH" % IFNAMSIZ, ifname, flags)
   ifr = fcntl.ioctl(s, scapy.SIOCSIFFLAGS, ifr)
 
 
diff --git a/tests/net_test/netlink.py b/tests/net_test/netlink.py
index 514ad08..f245901 100644
--- a/tests/net_test/netlink.py
+++ b/tests/net_test/netlink.py
@@ -200,6 +200,12 @@
     data = data[attrlen:]
     return (nlmsg, attributes), data
 
+  def _GetMsg(self, msgtype):  # Receives and parses one msgtype reply.
+    data = self._Recv()
+    if NLMsgHdr(data).type == NLMSG_ERROR:
+      self._ParseAck(data)  # Callers rely on this raising on error replies.
+    return self._ParseNLMsg(data, msgtype)[0]
+
   def _GetMsgList(self, msgtype, data, expect_done):
     out = []
     while data:
diff --git a/tests/net_test/packets.py b/tests/net_test/packets.py
index d92a97e..c02adc0 100644
--- a/tests/net_test/packets.py
+++ b/tests/net_test/packets.py
@@ -120,11 +120,12 @@
 def FIN(version, srcaddr, dstaddr, packet):
   ip = _GetIpLayer(version)
   original = packet.getlayer("TCP")
-  was_fin = (original.flags & TCP_FIN) != 0
+  was_syn_or_fin = (original.flags & (TCP_SYN | TCP_FIN)) != 0
+  ack_delta = was_syn_or_fin + len(original.payload)  # SYN/FIN count as 1.
   return ("TCP FIN",
           ip(src=srcaddr, dst=dstaddr) /
           scapy.TCP(sport=original.dport, dport=original.sport,
-                    ack=original.seq + was_fin, seq=original.ack,
+                    ack=original.seq + ack_delta, seq=original.ack,
                     flags=TCP_ACK | TCP_FIN, window=TCP_WINDOW))
 
 def GRE(version, srcaddr, dstaddr, proto, packet):
diff --git a/tests/net_test/run_net_test.sh b/tests/net_test/run_net_test.sh
index 85dc122..080aac7 100755
--- a/tests/net_test/run_net_test.sh
+++ b/tests/net_test/run_net_test.sh
@@ -1,7 +1,8 @@
 #!/bin/bash
 
 # Kernel configuration options.
-OPTIONS=" IPV6 IPV6_ROUTER_PREF IPV6_MULTIPLE_TABLES IPV6_ROUTE_INFO"
+OPTIONS=" DEBUG_SPINLOCK DEBUG_ATOMIC_SLEEP DEBUG_MUTEXES DEBUG_RT_MUTEXES"  # Debug checks.
+OPTIONS="$OPTIONS IPV6 IPV6_ROUTER_PREF IPV6_MULTIPLE_TABLES IPV6_ROUTE_INFO"
 OPTIONS="$OPTIONS TUN SYN_COOKIES IP_ADVANCED_ROUTER IP_MULTIPLE_TABLES"
 OPTIONS="$OPTIONS NETFILTER NETFILTER_ADVANCED NETFILTER_XTABLES"
 OPTIONS="$OPTIONS NETFILTER_XT_MARK NETFILTER_XT_TARGET_MARK"
@@ -11,6 +12,7 @@
 OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_TARGET_NFLOG"
 OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_MATCH_QUOTA CONFIG_NETFILTER_XT_MATCH_QUOTA2"
 OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_MATCH_QUOTA2_LOG"
+OPTIONS="$OPTIONS CONFIG_INET_UDP_DIAG CONFIG_INET_DIAG_DESTROY"  # sock_diag tests.
 
 # For 3.1 kernels, where devtmpfs is not on by default.
 OPTIONS="$OPTIONS DEVTMPFS DEVTMPFS_MOUNT"
diff --git a/tests/net_test/sock_diag.py b/tests/net_test/sock_diag.py
index 8c70eb3..a9de345 100755
--- a/tests/net_test/sock_diag.py
+++ b/tests/net_test/sock_diag.py
@@ -131,13 +131,17 @@
   def _EmptyInetDiagSockId():
     return InetDiagSockId(("\x00" * len(InetDiagSockId)))
 
+  def Dump(self, diag_req):
+    """Dumps sockets matching diag_req as (InetDiagMsg, attrs) pairs."""
+    return self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg)
+
   def DumpSockets(self, family, protocol, ext, states, sock_id):
     """Dumps sockets matching the specified parameters."""
     if sock_id is None:
       sock_id = self._EmptyInetDiagSockId()
 
     diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id))
-    return self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg)
+    return self.Dump(diag_req)
 
   def DumpAllInetSockets(self, protocol, sock_id=None, ext=0, states=0xffffffff):
     # DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it
@@ -177,12 +181,26 @@
       padded += "\x00" * (16 - len(padded))
     return padded
 
+  # For IPv4 addresses, the kernel seems only to fill in the first 4 bytes of
+  # src and dst, leaving the others unspecified. This seems like a bug because
+  # it might leak kernel memory contents, but regardless, work around it.
+  @staticmethod
+  def FixupDiagMsg(d):
+    if d.family == AF_INET:
+      d.id.src = d.id.src[:4] + "\x00" * 12
+      d.id.dst = d.id.dst[:4] + "\x00" * 12
+
   @staticmethod
   def DiagReqFromSocket(s):
     """Creates an InetDiagReqV2 that matches the specified socket."""
     family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
     protocol = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_PROTOCOL)
-    iface = s.getsockopt(SOL_SOCKET, net_test.SO_BINDTODEVICE)
+    if net_test.LINUX_VERSION >= (3, 8):  # Reading SO_BINDTODEVICE needs 3.8+.
+      iface = s.getsockopt(SOL_SOCKET, net_test.SO_BINDTODEVICE,
+                           net_test.IFNAMSIZ)
+      iface = net_test.GetInterfaceIndex(iface) if iface else 0
+    else:
+      iface = 0
     src, sport = s.getsockname()[:2]
     try:
       dst, dport = s.getpeername()[:2]
@@ -197,19 +215,35 @@
     sock_id = InetDiagSockId((sport, dport, src, dst, iface, "\x00" * 8))
     return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))
 
-  def GetSockDiagForFd(self, s):
-    """Gets an InetDiagMsg from the kernel for the specified socket."""
-    req = self.DiagReqFromSocket(s)
-    for diag_msg, attrs in self._Dump(SOCK_DIAG_BY_FAMILY, req, InetDiagMsg):
+  def FindSockDiagFromReq(self, req):
+    for diag_msg, attrs in self.Dump(req):
       return diag_msg
     raise ValueError("Dump of %s returned no sockets" % req)
 
-  def GetSockDiag(self, family, protocol, sock_id, ext=0, states=0xffffffff):
-    """Gets an InetDiagMsg from the kernel for the specified parameters."""
-    req = InetDiagReqV2((family, protocol, ext, states, sock_id))
+  def FindSockDiagFromFd(self, s):
+    """Gets an InetDiagMsg from the kernel for the specified socket."""
+    req = self.DiagReqFromSocket(s)
+    return self.FindSockDiagFromReq(req)
+
+  def GetSockDiag(self, req):
+    """Gets an InetDiagMsg from the kernel for the specified request."""
     self._SendNlRequest(SOCK_DIAG_BY_FAMILY, req.Pack(), netlink.NLM_F_REQUEST)
-    data = self._Recv()
-    return self._ParseNLMsg(data, InetDiagMsg)[0]
+    return self._GetMsg(InetDiagMsg)[0]  # Drop the attributes.
+
+  @staticmethod
+  def DiagReqFromDiagMsg(d, protocol):
+    """Constructs a diag_req from a diag_msg the kernel has given us."""
+    return InetDiagReqV2((d.family, protocol, 0, 1 << d.state, d.id))
+
+  def CloseSocket(self, req):
+    self._SendNlRequest(SOCK_DESTROY, req.Pack(),
+                        netlink.NLM_F_REQUEST | netlink.NLM_F_ACK)
+
+  def CloseSocketFromFd(self, s):
+    diag_msg = self.FindSockDiagFromFd(s)
+    protocol = s.getsockopt(SOL_SOCKET, net_test.SO_PROTOCOL)
+    req = self.DiagReqFromDiagMsg(diag_msg, protocol)
+    return self.CloseSocket(req)
 
 
 if __name__ == "__main__":
diff --git a/tests/net_test/sock_diag_test.py b/tests/net_test/sock_diag_test.py
index 2803bd2..5975931 100755
--- a/tests/net_test/sock_diag_test.py
+++ b/tests/net_test/sock_diag_test.py
@@ -14,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import errno
+from errno import *
+import os
 import random
 from socket import *
 import time
@@ -26,12 +27,15 @@
 import net_test
 import packets
 import sock_diag
+import threading
 
 
 NUM_SOCKETS = 100
 
 ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << sock_diag.TCP_TIME_WAIT)
 
+# TODO: Backport SOCK_DESTROY and delete this.
+HAVE_SOCK_DESTROY = net_test.LINUX_VERSION >= (4, 4)
 
 class SockDiagTest(multinetwork_base.MultiNetworkBaseTest):
 
@@ -48,39 +52,67 @@
     return socketpairs
 
   def setUp(self):
+    super(SockDiagTest, self).setUp()
     self.sock_diag = sock_diag.SockDiag()
-    self.socketpairs = self._CreateLotsOfSockets()
+    self.socketpairs = {}
 
   def tearDown(self):
     [s.close() for socketpair in self.socketpairs.values() for s in socketpair]
+    super(SockDiagTest, self).tearDown()
+
+  def testFixupDiagMsg(self):
+    src = "0a00fa02303030312030312038302031"
+    dst = "0808080841414141414141416f0a3230"
+    cookie = "4078678100000000"
+    sockid = sock_diag.InetDiagSockId((47436, 32069,
+                                       src.decode("hex"), dst.decode("hex"), 0,
+                                       cookie.decode("hex")))
+    msg4 = sock_diag.InetDiagMsg((AF_INET, IPPROTO_TCP, 0,
+                                  sock_diag.TCP_SYN_RECV, sockid,
+                                  980, 123, 456, 789, 5555))
+    # Make a copy, cstructs are mutable.
+    msg6 = sock_diag.InetDiagMsg(msg4.Pack())
+    msg6.family = AF_INET6
+
+    fixed6 = sock_diag.InetDiagMsg(msg6.Pack())
+    self.sock_diag.FixupDiagMsg(fixed6)
+    self.assertEquals(msg6.Pack(), fixed6.Pack())
+
+    fixed4 = sock_diag.InetDiagMsg(msg4.Pack())
+    self.sock_diag.FixupDiagMsg(fixed4)
+    msg4.id.src = src.decode("hex")[:4] + 12 * "\x00"
+    msg4.id.dst = dst.decode("hex")[:4] + 12 * "\x00"
+    self.assertEquals(msg4.Pack(), fixed4.Pack())
+
+  def assertSocketClosed(self, sock):
+    self.assertRaisesErrno(ENOTCONN, sock.getpeername)
+
+  def assertSocketConnected(self, sock):
+    sock.getpeername()  # No errors? Socket is alive and connected.
+
+  def assertSocketsClosed(self, socketpair):
+    for sock in socketpair:
+      self.assertSocketClosed(sock)
 
   def assertSockDiagMatchesSocket(self, s, diag_msg):
     family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
     self.assertEqual(diag_msg.family, family)
 
-    # TODO: The kernel (at least 3.10) seems only to fill in the first 4 bytes
-    # of src and dst in the case of IPv4 addresses. This means we can't just do
-    # something like:
-    #  self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
-    # because the trailing bytes might not match.
-    # This seems like a bug because it might leaks kernel memory contents, but
-    # regardless, work around that here.
-    addrlen = {AF_INET: 4, AF_INET6: 16}[family]
+    self.sock_diag.FixupDiagMsg(diag_msg)
 
     src, sport = s.getsockname()[0:2]
+    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
     self.assertEqual(diag_msg.id.sport, sport)
-    self.assertEqual(diag_msg.id.src[:addrlen],
-                     self.sock_diag.RawAddress(src))
 
     if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
       dst, dport = s.getpeername()[0:2]
-      self.assertEqual(diag_msg.id.dst[:addrlen],
-                       self.sock_diag.RawAddress(dst))
+      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
       self.assertEqual(diag_msg.id.dport, dport)
     else:
-      assertRaisesErrno(errno.ENOTCONN, s.getpeername)
+      self.assertRaisesErrno(ENOTCONN, s.getpeername)
 
   def testFindsAllMySockets(self):
+    self.socketpairs = self._CreateLotsOfSockets()
     sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                 states=ALL_NON_TIME_WAIT)
     self.assertGreaterEqual(len(sockets), NUM_SOCKETS)
@@ -103,9 +135,328 @@
     random.shuffle(socketpairs)
     for socketpair in socketpairs:
       for sock in socketpair:
+        # Check that we can find a diag_msg by scanning a dump.
         self.assertSockDiagMatchesSocket(
             sock,
-            self.sock_diag.GetSockDiagForFd(sock))
+            self.sock_diag.FindSockDiagFromFd(sock))
+        diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
+
+        # Check that we can find a diag_msg once we know the cookie.
+        req = self.sock_diag.DiagReqFromSocket(sock)
+        req.id.cookie = diag_msg.id.cookie
+        req.states = 1 << diag_msg.state
+        diag_msg = self.sock_diag.GetSockDiag(req)
+        self.assertSockDiagMatchesSocket(sock, diag_msg)
+
+  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+  def testClosesSockets(self):
+    self.socketpairs = self._CreateLotsOfSockets()
+    for (addr, _, _), socketpair in self.socketpairs.iteritems():
+      # Close one of the sockets.
+      # This will send a RST that will close the other side as well.
+      s = random.choice(socketpair)
+      if random.randrange(0, 2) == 1:
+        self.sock_diag.CloseSocketFromFd(s)
+      else:
+        diag_msg = self.sock_diag.FindSockDiagFromFd(s)
+
+        # Get the cookie wrong and ensure that we get an error and the socket
+        # is not closed.
+        real_cookie = diag_msg.id.cookie
+        diag_msg.id.cookie = os.urandom(len(real_cookie))
+        req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
+        self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
+        self.assertSocketConnected(s)
+
+        # Now close it with the correct cookie.
+        req.id.cookie = real_cookie
+        self.sock_diag.CloseSocket(req)
+
+      # Check that both sockets in the pair are closed.
+      self.assertSocketsClosed(socketpair)
+
+  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+  def testNonTcpSockets(self):
+    s = socket(AF_INET6, SOCK_DGRAM, 0)
+    s.connect(("::1", 53))
+    self.sock_diag.FindSockDiagFromFd(s)  # Raises ValueError if not found.
+    self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s)
+
+  def testNonSockDiagCommand(self):
+    def DiagDump(code):
+      sock_id = self.sock_diag._EmptyInetDiagSockId()
+      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
+                                     sock_id))
+      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg)
+
+    op = sock_diag.SOCK_DIAG_BY_FAMILY
+    DiagDump(op)  # No errors? Good.
+    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)
+
+  # TODO:
+  # Test that killing unix sockets returns EOPNOTSUPP.
+
+
+class SocketExceptionThread(threading.Thread):
+
+  def __init__(self, sock, operation):
+    self.exception = None
+    super(SocketExceptionThread, self).__init__()
+    self.daemon = True
+    self.sock = sock
+    self.operation = operation
+
+  def run(self):
+    try:
+      self.operation(self.sock)
+    except Exception, e:
+      self.exception = e
+
+
+# TODO: Take a tun fd as input, make this a utility class, and reuse at least
+# in forwarding_test.
+class TcpTest(SockDiagTest):
+
+  NOT_YET_ACCEPTED = -1
+
+  def setUp(self):
+    super(TcpTest, self).setUp()
+    self.netid = random.choice(self.tuns.keys())
+
+  def OpenListenSocket(self, version):
+    self.port = packets.RandomPort()
+    family = {4: AF_INET, 6: AF_INET6}[version]
+    address = {4: "0.0.0.0", 6: "::"}[version]
+    s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
+    s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
+    s.bind((address, self.port))
+    # We haven't configured inbound iptables marking, so bind explicitly.
+    self.SelectInterface(s, self.netid, "mark")
+    s.listen(100)
+    return s
+
+  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
+    pkt = super(TcpTest, self)._ReceiveAndExpectResponse(netid, packet,
+                                                         reply, msg)
+    self.last_packet = pkt
+    return pkt
+
+  def ReceivePacketOn(self, netid, packet):
+    super(TcpTest, self).ReceivePacketOn(netid, packet)
+    self.last_packet = packet
+
+  def RstPacket(self):
+    return packets.RST(self.version, self.myaddr, self.remoteaddr,
+                       self.last_packet)
+
+  def IncomingConnection(self, version, end_state, netid):
+    self.version = version
+    self.s = self.OpenListenSocket(version)
+    self.end_state = end_state
+
+    remoteaddr = self.remoteaddr = self.GetRemoteAddress(version)
+    myaddr = self.myaddr = self.MyAddress(version, netid)
+
+    if end_state == sock_diag.TCP_LISTEN:
+      return
+
+    desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
+    synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
+    msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
+    reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
+    if end_state == sock_diag.TCP_SYN_RECV:
+      return
+
+    establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
+    self.ReceivePacketOn(netid, establishing_ack)
+
+    if end_state == self.NOT_YET_ACCEPTED:
+      return
+
+    self.accepted, _ = self.s.accept()
+    if end_state == sock_diag.TCP_ESTABLISHED:
+      return
+
+    desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
+                             payload=net_test.UDP_PAYLOAD)
+    self.accepted.send(net_test.UDP_PAYLOAD)
+    self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
+
+    desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
+    fin = packets._GetIpLayer(version)(str(fin))
+    ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
+    msg = "Received %s, expected to see reply %s" % (desc, ack_desc)
+
+    # TODO: Why can't we use this?
+    #   self._ReceiveAndExpectResponse(netid, fin, ack, msg)
+    self.ReceivePacketOn(netid, fin)
+    time.sleep(0.1)
+    self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
+    if end_state == sock_diag.TCP_CLOSE_WAIT:
+      return
+
+    raise ValueError("Invalid TCP state %d specified" % end_state)
+
+  def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
+    """Closes the socket and checks whether a RST is sent or not."""
+    if sock is not None:
+      self.assertIsNone(req, "Must specify sock or req, not both")
+      self.sock_diag.CloseSocketFromFd(sock)
+      self.assertRaisesErrno(EINVAL, sock.accept)
+    else:
+      self.assertIsNotNone(req, "Must specify sock or req, not both")
+      self.sock_diag.CloseSocket(req)
+
+    if expect_reset:
+      desc, rst = self.RstPacket()
+      msg = "%s: expecting %s: " % (msg, desc)
+      self.ExpectPacketOn(self.netid, msg, rst)
+    else:
+      msg = "%s: " % msg
+      self.ExpectNoPacketsOn(self.netid, msg)
+
+    if sock is not None and do_close:
+      sock.close()
+
+  def CheckTcpReset(self, state, statename):
+    for version in [4, 6]:
+      msg = "Closing incoming IPv%d %s socket" % (version, statename)
+      self.IncomingConnection(version, state, self.netid)
+      self.CheckRstOnClose(self.s, None, False, msg)
+      if state != sock_diag.TCP_LISTEN:
+        msg = "Closing accepted IPv%d %s socket" % (version, statename)
+        self.CheckRstOnClose(self.accepted, None, True, msg)
+
+  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+  def testTcpResets(self):
+    """Checks that closing sockets in appropriate states sends a RST."""
+    self.CheckTcpReset(sock_diag.TCP_LISTEN, "TCP_LISTEN")
+    self.CheckTcpReset(sock_diag.TCP_ESTABLISHED, "TCP_ESTABLISHED")
+    self.CheckTcpReset(sock_diag.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
+
+  def FindChildSockets(self, s):
+    """Returns diag_reqs for the SYN_RECV/ESTABLISHED children of socket s."""
+    d = self.sock_diag.FindSockDiagFromFd(s)
+    req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
+    req.states = 1 << sock_diag.TCP_SYN_RECV | 1 << sock_diag.TCP_ESTABLISHED
+    req.id.cookie = "\x00" * 8
+    children = self.sock_diag.Dump(req)
+    return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
+            for d, _ in children]
+
+  def CheckChildSocket(self, state, statename, parent_first):
+    for version in [4, 6]:
+      self.IncomingConnection(version, state, self.netid)
+
+      d = self.sock_diag.FindSockDiagFromFd(self.s)
+      parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
+      children = self.FindChildSockets(self.s)
+      self.assertEquals(1, len(children))
+
+      is_established = (state == self.NOT_YET_ACCEPTED)
+
+      # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
+      # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
+      # Before 4.4, we can see those sockets in dumps, but we can't fetch
+      # or close them.
+      can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)
+
+      for child in children:
+        if can_close_children:
+          self.sock_diag.GetSockDiag(child)  # No errors? Good, child found.
+        else:
+          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
+
+      def CloseParent(expect_reset):
+        msg = "Closing parent IPv%d %s socket %s child" % (
+            version, statename, "before" if parent_first else "after")
+        self.CheckRstOnClose(self.s, None, expect_reset, msg)
+        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, parent)
+
+      def CheckChildrenClosed():
+        for child in children:
+          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
+
+      def CloseChildren():
+        for child in children:
+          msg = "Closing child IPv%d %s socket %s parent" % (
+              version, statename, "after" if parent_first else "before")
+          self.sock_diag.GetSockDiag(child)
+          self.CheckRstOnClose(None, child, is_established, msg)
+          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
+        CheckChildrenClosed()
+
+      if parent_first:
+        # Closing the parent will close child sockets, which will send a RST,
+        # iff they are already established.
+        CloseParent(is_established)
+        if is_established:
+          CheckChildrenClosed()
+        elif can_close_children:
+          CloseChildren()
+          CheckChildrenClosed()
+        self.s.close()
+      else:
+        if can_close_children:
+          CloseChildren()
+        CloseParent(False)
+        self.s.close()
+
+  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+  def testChildSockets(self):
+    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", False)
+    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", True)
+    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", False)
+    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", True)
+
+  def CloseDuringBlockingCall(self, sock, call, expected_errno):
+    thread = SocketExceptionThread(sock, call)
+    thread.start()
+    time.sleep(0.1)
+    self.sock_diag.CloseSocketFromFd(sock)
+    thread.join(1)
+    self.assertFalse(thread.is_alive())
+    self.assertIsNotNone(thread.exception)
+    self.assertTrue(isinstance(thread.exception, IOError),
+                    "Expected IOError, got %s" % thread.exception)
+    self.assertEqual(expected_errno, thread.exception.errno)
+    self.assertSocketClosed(sock)
+
+  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+  def testAcceptInterrupted(self):
+    """Tests that accept() is interrupted by SOCK_DESTROY."""
+    for version in [4, 6]:
+      self.IncomingConnection(version, sock_diag.TCP_LISTEN, self.netid)
+      self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
+      self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
+      self.assertRaisesErrno(EINVAL, self.s.accept)
+
+  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+  def testReadInterrupted(self):
+    """Tests that read() is interrupted by SOCK_DESTROY."""
+    for version in [4, 6]:
+      self.IncomingConnection(version, sock_diag.TCP_ESTABLISHED, self.netid)
+      self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
+                                   ECONNABORTED)
+      self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
+
+  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+  def testConnectInterrupted(self):
+    """Tests that connect() is interrupted by SOCK_DESTROY."""
+    for version in [4, 6]:
+      family = {4: AF_INET, 6: AF_INET6}[version]
+      s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
+      self.SelectInterface(s, self.netid, "mark")
+      remoteaddr = self.GetRemoteAddress(version)
+      s.bind(("", 0))
+      _, sport = s.getsockname()[:2]
+      self.CloseDuringBlockingCall(
+          s, lambda sock: sock.connect((remoteaddr, 53)), ECONNABORTED)
+      desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
+                              remoteaddr, sport=sport, seq=None)
+      self.ExpectPacketOn(self.netid, desc, syn)
+      msg = "SOCK_DESTROY of socket in connect, expected no RST"
+      self.ExpectNoPacketsOn(self.netid, msg)
 
 
 if __name__ == "__main__":