Merge "Fix for simpleperf report on 'fugu'."
diff --git a/f2fs_utils/Android.mk b/f2fs_utils/Android.mk
index 0f39116..2bc190f 100644
--- a/f2fs_utils/Android.mk
+++ b/f2fs_utils/Android.mk
@@ -17,6 +17,7 @@
LOCAL_SRC_FILES := f2fs_ioutils.c
LOCAL_C_INCLUDES := external/f2fs-tools/include external/f2fs-tools/mkfs
LOCAL_STATIC_LIBRARIES := \
+ libselinux \
libsparse_host \
libext2_uuid-host \
libz
diff --git a/tests/net_test/iproute.py b/tests/net_test/iproute.py
index 7c9eec0..e0f2837 100644
--- a/tests/net_test/iproute.py
+++ b/tests/net_test/iproute.py
@@ -454,10 +454,7 @@
# implement parsing dump results.
raise NotImplementedError("IPv4 RTM_GETADDR not implemented.")
self._Address(6, RTM_GETADDR, address, 0, 0, RT_SCOPE_UNIVERSE, ifindex)
- data = self._Recv()
- if NLMsgHdr(data).type == NLMSG_ERROR:
- self._ParseAck(data)
- return self._ParseNLMsg(data, IfAddrMsg)[0]
+ return self._GetMsg(IfAddrMsg)  # Same recv / NLMSG_ERROR check / parse sequence as the removed lines.
def _Route(self, version, command, table, dest, prefixlen, nexthop, dev,
mark, uid):
diff --git a/tests/net_test/net_test.py b/tests/net_test/net_test.py
index f108aa8..d7ea013 100755
--- a/tests/net_test/net_test.py
+++ b/tests/net_test/net_test.py
@@ -56,6 +56,8 @@
IPV6_FL_S_EXCL = 1
IPV6_FL_S_ANY = 255
+IFNAMSIZ = 16  # Width of the interface-name field at the start of the ifreq buffers packed below.
+
IPV4_PING = "\x08\x00\x00\x00\x0a\xce\x00\x03"
IPV6_PING = "\x80\x00\x00\x00\x0a\xce\x00\x03"
@@ -171,9 +173,9 @@
def GetInterfaceIndex(ifname):
s = IPv4PingSocket()
- ifr = struct.pack("16si", ifname, 0)
+ ifr = struct.pack("%dsi" % IFNAMSIZ, ifname, 0)
ifr = fcntl.ioctl(s, scapy.SIOCGIFINDEX, ifr)
- return struct.unpack("16si", ifr)[1]
+ return struct.unpack("%dsi" % IFNAMSIZ, ifr)[1]
def SetInterfaceHWAddr(ifname, hwaddr):
@@ -182,20 +184,20 @@
hwaddr = hwaddr.decode("hex")
if len(hwaddr) != 6:
raise ValueError("Unknown hardware address length %d" % len(hwaddr))
- ifr = struct.pack("16sH6s", ifname, scapy.ARPHDR_ETHER, hwaddr)
+ ifr = struct.pack("%dsH6s" % IFNAMSIZ, ifname, scapy.ARPHDR_ETHER, hwaddr)
fcntl.ioctl(s, SIOCSIFHWADDR, ifr)
def SetInterfaceState(ifname, up):
s = IPv4PingSocket()
- ifr = struct.pack("16sH", ifname, 0)
+ ifr = struct.pack("%dsH" % IFNAMSIZ, ifname, 0)
ifr = fcntl.ioctl(s, scapy.SIOCGIFFLAGS, ifr)
- _, flags = struct.unpack("16sH", ifr)
+ _, flags = struct.unpack("%dsH" % IFNAMSIZ, ifr)
if up:
flags |= scapy.IFF_UP
else:
flags &= ~scapy.IFF_UP
- ifr = struct.pack("16sH", ifname, flags)
+ ifr = struct.pack("%dsH" % IFNAMSIZ, ifname, flags)
ifr = fcntl.ioctl(s, scapy.SIOCSIFFLAGS, ifr)
diff --git a/tests/net_test/netlink.py b/tests/net_test/netlink.py
index 514ad08..f245901 100644
--- a/tests/net_test/netlink.py
+++ b/tests/net_test/netlink.py
@@ -200,6 +200,12 @@
data = data[attrlen:]
return (nlmsg, attributes), data
+ def _GetMsg(self, msgtype):
+ data = self._Recv()  # One receive: parses a single reply message, not a multi-message dump.
+ if NLMsgHdr(data).type == NLMSG_ERROR:
+ self._ParseAck(data)  # NOTE(review): presumably raises on a nonzero netlink error code.
+ return self._ParseNLMsg(data, msgtype)[0]  # Presumably the (msg, attributes) pair; remaining data is dropped.
+
def _GetMsgList(self, msgtype, data, expect_done):
out = []
while data:
diff --git a/tests/net_test/packets.py b/tests/net_test/packets.py
index d92a97e..c02adc0 100644
--- a/tests/net_test/packets.py
+++ b/tests/net_test/packets.py
@@ -120,11 +120,12 @@
def FIN(version, srcaddr, dstaddr, packet):
ip = _GetIpLayer(version)
original = packet.getlayer("TCP")
- was_fin = (original.flags & TCP_FIN) != 0
+ was_syn_or_fin = (original.flags & (TCP_SYN | TCP_FIN)) != 0  # A SYN or FIN occupies one sequence number.
+ ack_delta = was_syn_or_fin + len(original.payload)  # ...and every payload byte occupies one more.
return ("TCP FIN",
ip(src=srcaddr, dst=dstaddr) /
scapy.TCP(sport=original.dport, dport=original.sport,
- ack=original.seq + was_fin, seq=original.ack,
+ ack=original.seq + ack_delta, seq=original.ack,
flags=TCP_ACK | TCP_FIN, window=TCP_WINDOW))
def GRE(version, srcaddr, dstaddr, proto, packet):
diff --git a/tests/net_test/run_net_test.sh b/tests/net_test/run_net_test.sh
index 85dc122..080aac7 100755
--- a/tests/net_test/run_net_test.sh
+++ b/tests/net_test/run_net_test.sh
@@ -1,7 +1,8 @@
#!/bin/bash
# Kernel configuration options.
-OPTIONS=" IPV6 IPV6_ROUTER_PREF IPV6_MULTIPLE_TABLES IPV6_ROUTE_INFO"
+OPTIONS=" DEBUG_SPINLOCK DEBUG_ATOMIC_SLEEP DEBUG_MUTEXES DEBUG_RT_MUTEXES"  # Kernel lock / sleep-in-atomic debug checks.
+OPTIONS="$OPTIONS IPV6 IPV6_ROUTER_PREF IPV6_MULTIPLE_TABLES IPV6_ROUTE_INFO"
OPTIONS="$OPTIONS TUN SYN_COOKIES IP_ADVANCED_ROUTER IP_MULTIPLE_TABLES"
OPTIONS="$OPTIONS NETFILTER NETFILTER_ADVANCED NETFILTER_XTABLES"
OPTIONS="$OPTIONS NETFILTER_XT_MARK NETFILTER_XT_TARGET_MARK"
@@ -11,6 +12,7 @@
OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_TARGET_NFLOG"
OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_MATCH_QUOTA CONFIG_NETFILTER_XT_MATCH_QUOTA2"
OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_MATCH_QUOTA2_LOG"
+OPTIONS="$OPTIONS CONFIG_INET_UDP_DIAG CONFIG_INET_DIAG_DESTROY"  # UDP sock_diag and SOCK_DESTROY, used by sock_diag_test.py.
# For 3.1 kernels, where devtmpfs is not on by default.
OPTIONS="$OPTIONS DEVTMPFS DEVTMPFS_MOUNT"
diff --git a/tests/net_test/sock_diag.py b/tests/net_test/sock_diag.py
index 8c70eb3..a9de345 100755
--- a/tests/net_test/sock_diag.py
+++ b/tests/net_test/sock_diag.py
@@ -131,13 +131,17 @@
def _EmptyInetDiagSockId():
return InetDiagSockId(("\x00" * len(InetDiagSockId)))
+ def Dump(self, diag_req):
+ """Dumps sockets matching diag_req, as (InetDiagMsg, attributes) pairs."""
+ return self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg)
+
def DumpSockets(self, family, protocol, ext, states, sock_id):
"""Dumps sockets matching the specified parameters."""
if sock_id is None:
sock_id = self._EmptyInetDiagSockId()
diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id))
- return self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg)
+ return self.Dump(diag_req)
def DumpAllInetSockets(self, protocol, sock_id=None, ext=0, states=0xffffffff):
# DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it
@@ -177,12 +181,26 @@
padded += "\x00" * (16 - len(padded))
return padded
+ # For IPv4 addresses, the kernel seems only to fill in the first 4 bytes of
+ # src and dst, leaving the others unspecified. This seems like a bug because
+ # it might leak kernel memory contents, but regardless, work around it.
+ @staticmethod
+ def FixupDiagMsg(d):
+ if d.family == AF_INET:  # Modifies d in place; non-IPv4 messages are left untouched.
+ d.id.src = d.id.src[:4] + "\x00" * 12
+ d.id.dst = d.id.dst[:4] + "\x00" * 12
+
@staticmethod
def DiagReqFromSocket(s):
"""Creates an InetDiagReqV2 that matches the specified socket."""
family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
protocol = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_PROTOCOL)
+ if net_test.LINUX_VERSION >= (3, 8):  # On older kernels the interface is left unspecified (0).
+ iface = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_BINDTODEVICE,
+ net_test.IFNAMSIZ)
+ iface = net_test.GetInterfaceIndex(iface) if iface else 0  # Unbound socket -> empty name -> 0.
+ else:
+ iface = 0
src, sport = s.getsockname()[:2]
try:
dst, dport = s.getpeername()[:2]
@@ -197,19 +215,35 @@
sock_id = InetDiagSockId((sport, dport, src, dst, iface, "\x00" * 8))
return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))
- def GetSockDiagForFd(self, s):
- """Gets an InetDiagMsg from the kernel for the specified socket."""
- req = self.DiagReqFromSocket(s)
- for diag_msg, attrs in self._Dump(SOCK_DIAG_BY_FAMILY, req, InetDiagMsg):
+ def FindSockDiagFromReq(self, req):
+ for diag_msg, attrs in self.Dump(req):  # First match wins; an empty dump raises ValueError.
return diag_msg
raise ValueError("Dump of %s returned no sockets" % req)
- def GetSockDiag(self, family, protocol, sock_id, ext=0, states=0xffffffff):
- """Gets an InetDiagMsg from the kernel for the specified parameters."""
- req = InetDiagReqV2((family, protocol, ext, states, sock_id))
+ def FindSockDiagFromFd(self, s):
+ """Gets an InetDiagMsg from the kernel for the specified socket."""
+ req = self.DiagReqFromSocket(s)
+ return self.FindSockDiagFromReq(req)
+
+ def GetSockDiag(self, req):
+ """Gets an InetDiagMsg from the kernel for the specified request."""
self._SendNlRequest(SOCK_DIAG_BY_FAMILY, req.Pack(), netlink.NLM_F_REQUEST)
- data = self._Recv()
- return self._ParseNLMsg(data, InetDiagMsg)[0]
+ return self._GetMsg(InetDiagMsg)[0]  # Callers get the InetDiagMsg itself, not (msg, attributes).
+
+ @staticmethod
+ def DiagReqFromDiagMsg(d, protocol):
+ """Constructs a diag_req from a diag_msg the kernel has given us."""
+ return InetDiagReqV2((d.family, protocol, 0, 1 << d.state, d.id))
+
+ def CloseSocket(self, req):
+ self._SendNlRequest(SOCK_DESTROY, req.Pack(),  # Ask the kernel to destroy the socket matching req.
+ netlink.NLM_F_REQUEST | netlink.NLM_F_ACK)  # Errors (e.g. ENOENT) are expected to raise.
+
+ def CloseSocketFromFd(self, s):
+ diag_msg = self.FindSockDiagFromFd(s)
+ protocol = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_PROTOCOL)
+ req = self.DiagReqFromDiagMsg(diag_msg, protocol)  # Reuses the kernel-supplied id, cookie included.
+ return self.CloseSocket(req)
if __name__ == "__main__":
diff --git a/tests/net_test/sock_diag_test.py b/tests/net_test/sock_diag_test.py
index 2803bd2..5975931 100755
--- a/tests/net_test/sock_diag_test.py
+++ b/tests/net_test/sock_diag_test.py
@@ -14,7 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import errno
+from errno import *  # Errno names (ENOTCONN, ENOENT, EINVAL, ...) are used unqualified below.
+import os  # os.urandom, used to forge a wrong socket cookie.
import random
from socket import *
import time
@@ -26,12 +27,15 @@
import net_test
import packets
import sock_diag
+import threading  # For SocketExceptionThread.
NUM_SOCKETS = 100
ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << sock_diag.TCP_TIME_WAIT)
+# TODO: Backport SOCK_DESTROY and delete this.
+HAVE_SOCK_DESTROY = net_test.LINUX_VERSION >= (4, 4)  # Gates the SOCK_DESTROY tests via unittest.skipUnless.
class SockDiagTest(multinetwork_base.MultiNetworkBaseTest):
@@ -48,39 +52,67 @@
return socketpairs
def setUp(self):
+ super(SockDiagTest, self).setUp()
self.sock_diag = sock_diag.SockDiag()
- self.socketpairs = self._CreateLotsOfSockets()
+ self.socketpairs = {}  # Tests that need socket pairs create them themselves.
def tearDown(self):
[s.close() for socketpair in self.socketpairs.values() for s in socketpair]
+ super(SockDiagTest, self).tearDown()
+
+ def testFixupDiagMsg(self):
+ src = "0a00fa02303030312030312038302031"
+ dst = "0808080841414141414141416f0a3230"
+ cookie = "4078678100000000"
+ sockid = sock_diag.InetDiagSockId((47436, 32069,
+ src.decode("hex"), dst.decode("hex"), 0,
+ cookie.decode("hex")))
+ msg4 = sock_diag.InetDiagMsg((AF_INET, IPPROTO_TCP, 0,
+ sock_diag.TCP_SYN_RECV, sockid,
+ 980, 123, 456, 789, 5555))
+ # Make a copy, cstructs are mutable.
+ msg6 = sock_diag.InetDiagMsg(msg4.Pack())
+ msg6.family = AF_INET6
+
+ fixed6 = sock_diag.InetDiagMsg(msg6.Pack())
+ self.sock_diag.FixupDiagMsg(fixed6)
+ self.assertEquals(msg6.Pack(), fixed6.Pack())  # IPv6 messages must be left unchanged.
+
+ fixed4 = sock_diag.InetDiagMsg(msg4.Pack())
+ self.sock_diag.FixupDiagMsg(fixed4)
+ msg4.id.src = src.decode("hex")[:4] + 12 * "\x00"
+ msg4.id.dst = dst.decode("hex")[:4] + 12 * "\x00"
+ self.assertEquals(msg4.Pack(), fixed4.Pack())
+
+ def assertSocketClosed(self, sock):
+ self.assertRaisesErrno(ENOTCONN, sock.getpeername)
+
+ def assertSocketConnected(self, sock):
+ sock.getpeername() # No errors? Socket is alive and connected.
+
+ def assertSocketsClosed(self, socketpair):
+ for sock in socketpair:
+ self.assertSocketClosed(sock)
def assertSockDiagMatchesSocket(self, s, diag_msg):
family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
self.assertEqual(diag_msg.family, family)
- # TODO: The kernel (at least 3.10) seems only to fill in the first 4 bytes
- # of src and dst in the case of IPv4 addresses. This means we can't just do
- # something like:
- # self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
- # because the trailing bytes might not match.
- # This seems like a bug because it might leaks kernel memory contents, but
- # regardless, work around that here.
- addrlen = {AF_INET: 4, AF_INET6: 16}[family]
+ self.sock_diag.FixupDiagMsg(diag_msg)  # Note: zeroes diag_msg's IPv4 address padding in place.
src, sport = s.getsockname()[0:2]
+ self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
self.assertEqual(diag_msg.id.sport, sport)
- self.assertEqual(diag_msg.id.src[:addrlen],
- self.sock_diag.RawAddress(src))
if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
dst, dport = s.getpeername()[0:2]
- self.assertEqual(diag_msg.id.dst[:addrlen],
- self.sock_diag.RawAddress(dst))
+ self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
self.assertEqual(diag_msg.id.dport, dport)
else:
- assertRaisesErrno(errno.ENOTCONN, s.getpeername)
+ self.assertRaisesErrno(ENOTCONN, s.getpeername)  # Unconnected socket: no peer.
def testFindsAllMySockets(self):
+ self.socketpairs = self._CreateLotsOfSockets()
sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
states=ALL_NON_TIME_WAIT)
self.assertGreaterEqual(len(sockets), NUM_SOCKETS)
@@ -103,9 +135,330 @@
random.shuffle(socketpairs)
for socketpair in socketpairs:
for sock in socketpair:
+ # Check that we can find a diag_msg by scanning a dump.
self.assertSockDiagMatchesSocket(
sock,
- self.sock_diag.GetSockDiagForFd(sock))
+ self.sock_diag.FindSockDiagFromFd(sock))
+ diag_msg = self.sock_diag.FindSockDiagFromFd(sock)  # This socket's own diag_msg.
+
+ # Check that we can find a diag_msg once we know the cookie.
+ req = self.sock_diag.DiagReqFromSocket(sock)
+ req.id.cookie = diag_msg.id.cookie
+ req.states = 1 << diag_msg.state  # Must be this socket's state, not one left over from earlier.
+ diag_msg = self.sock_diag.GetSockDiag(req)
+ self.assertSockDiagMatchesSocket(sock, diag_msg)
+
+ @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+ def testClosesSockets(self):
+ self.socketpairs = self._CreateLotsOfSockets()
+ for socketpair in self.socketpairs.itervalues():
+ # Close one of the sockets.
+ # This will send a RST that will close the other side as well.
+ s = random.choice(socketpair)
+ if random.randrange(0, 2) == 1:
+ self.sock_diag.CloseSocketFromFd(s)
+ else:
+ # Fetch the kernel's diag_msg for s so its cookie can be tampered with.
+ diag_msg = self.sock_diag.FindSockDiagFromFd(s)
+
+ # Get the cookie wrong and ensure that we get an error and the socket
+ # is not closed.
+ real_cookie = diag_msg.id.cookie
+ diag_msg.id.cookie = os.urandom(len(real_cookie))
+ req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
+ self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
+ self.assertSocketConnected(s)
+
+ # Now close it with the correct cookie.
+ req.id.cookie = real_cookie
+ self.sock_diag.CloseSocket(req)
+
+ # Check that both sockets in the pair are closed.
+ self.assertSocketsClosed(socketpair)
+
+ @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+ def testNonTcpSockets(self):
+ s = socket(AF_INET6, SOCK_DGRAM, 0)
+ s.connect(("::1", 53))
+ self.sock_diag.FindSockDiagFromFd(s)  # No errors? The UDP socket is visible to sock_diag.
+ self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s)
+
+ def testNonSockDiagCommand(self):
+ def DiagDump(code):
+ sock_id = self.sock_diag._EmptyInetDiagSockId()
+ req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
+ sock_id))
+ self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg)
+
+ op = sock_diag.SOCK_DIAG_BY_FAMILY
+ DiagDump(op) # No errors? Good.
+ self.assertRaisesErrno(EINVAL, DiagDump, op + 17)
+
+ # TODO:
+ # Test that killing unix sockets returns EOPNOTSUPP.
+
+
+class SocketExceptionThread(threading.Thread):
+ """Runs operation(sock) in a daemon thread, saving any exception it raises."""
+ def __init__(self, sock, operation):
+ self.exception = None
+ super(SocketExceptionThread, self).__init__()
+ self.daemon = True
+ self.sock = sock
+ self.operation = operation
+
+ def run(self):
+ try:
+ self.operation(self.sock)
+ except Exception as e:  # Handed back to the caller through self.exception.
+ self.exception = e
+
+
+# TODO: Take a tun fd as input, make this a utility class, and reuse at least
+# in forwarding_test.
+class TcpTest(SockDiagTest):  # NOTE(review): also re-runs every test method inherited from SockDiagTest.
+ """Tests SOCK_DESTROY on TCP sockets whose peer is simulated with injected packets."""
+ NOT_YET_ACCEPTED = -1  # Pseudo end_state: handshake completed, accept() not yet called.
+
+ def setUp(self):
+ super(TcpTest, self).setUp()
+ # self.sock_diag was already created by SockDiagTest.setUp(); don't leak a second one.
+ self.netid = random.choice(self.tuns.keys())
+
+ def OpenListenSocket(self, version):
+ self.port = packets.RandomPort()
+ family = {4: AF_INET, 6: AF_INET6}[version]
+ address = {4: "0.0.0.0", 6: "::"}[version]
+ s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
+ s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
+ s.bind((address, self.port))
+ # We haven't configured inbound iptables marking, so bind explicitly.
+ self.SelectInterface(s, self.netid, "mark")
+ s.listen(100)
+ return s
+
+ def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
+ pkt = super(TcpTest, self)._ReceiveAndExpectResponse(netid, packet,
+ reply, msg)
+ self.last_packet = pkt
+ return pkt
+
+ def ReceivePacketOn(self, netid, packet):
+ super(TcpTest, self).ReceivePacketOn(netid, packet)
+ self.last_packet = packet
+
+ def RstPacket(self):
+ return packets.RST(self.version, self.myaddr, self.remoteaddr,
+ self.last_packet)
+
+ def IncomingConnection(self, version, end_state, netid):
+ self.version = version
+ self.s = self.OpenListenSocket(version)
+ self.end_state = end_state
+
+ remoteaddr = self.remoteaddr = self.GetRemoteAddress(version)
+ myaddr = self.myaddr = self.MyAddress(version, netid)
+
+ if end_state == sock_diag.TCP_LISTEN:
+ return
+
+ desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
+ synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
+ msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
+ reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
+ if end_state == sock_diag.TCP_SYN_RECV:
+ return
+
+ establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
+ self.ReceivePacketOn(netid, establishing_ack)
+
+ if end_state == self.NOT_YET_ACCEPTED:
+ return
+
+ self.accepted, _ = self.s.accept()
+ if end_state == sock_diag.TCP_ESTABLISHED:
+ return
+
+ desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
+ payload=net_test.UDP_PAYLOAD)
+ self.accepted.send(net_test.UDP_PAYLOAD)
+ self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
+
+ desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
+ fin = packets._GetIpLayer(version)(str(fin))
+ ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
+ msg = "Received %s, expected to see reply %s" % (desc, ack_desc)
+
+ # TODO: Why can't we use this?
+ # self._ReceiveAndExpectResponse(netid, fin, ack, msg)
+ self.ReceivePacketOn(netid, fin)
+ time.sleep(0.1)
+ self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
+ if end_state == sock_diag.TCP_CLOSE_WAIT:
+ return
+
+ raise ValueError("Invalid TCP state %d specified" % end_state)
+
+ def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
+ """Closes the socket and checks whether a RST is sent or not."""
+ if sock is not None:
+ self.assertIsNone(req, "Must specify sock or req, not both")
+ self.sock_diag.CloseSocketFromFd(sock)
+ self.assertRaisesErrno(EINVAL, sock.accept)
+ else:
+ self.assertIsNotNone(req, "Must specify sock or req, not both")
+ self.sock_diag.CloseSocket(req)
+
+ if expect_reset:
+ desc, rst = self.RstPacket()
+ msg = "%s: expecting %s: " % (msg, desc)
+ self.ExpectPacketOn(self.netid, msg, rst)
+ else:
+ msg = "%s: " % msg
+ self.ExpectNoPacketsOn(self.netid, msg)
+
+ if sock is not None and do_close:
+ sock.close()
+
+ def CheckTcpReset(self, state, statename):
+ for version in [4, 6]:
+ msg = "Closing incoming IPv%d %s socket" % (version, statename)
+ self.IncomingConnection(version, state, self.netid)
+ self.CheckRstOnClose(self.s, None, False, msg)  # Listening socket: no RST expected.
+ if state != sock_diag.TCP_LISTEN:
+ msg = "Closing accepted IPv%d %s socket" % (version, statename)
+ self.CheckRstOnClose(self.accepted, None, True, msg)  # Accepted socket: RST expected.
+
+ @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+ def testTcpResets(self):
+ """Checks that closing sockets in appropriate states sends a RST."""
+ self.CheckTcpReset(sock_diag.TCP_LISTEN, "TCP_LISTEN")
+ self.CheckTcpReset(sock_diag.TCP_ESTABLISHED, "TCP_ESTABLISHED")
+ self.CheckTcpReset(sock_diag.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
+
+ def FindChildSockets(self, s):
+ """Finds the SYN_RECV child sockets of a given listening socket."""
+ d = self.sock_diag.FindSockDiagFromFd(s)
+ req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
+ req.states = 1 << sock_diag.TCP_SYN_RECV | 1 << sock_diag.TCP_ESTABLISHED
+ req.id.cookie = "\x00" * 8
+ children = self.sock_diag.Dump(req)
+ return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
+ for d, _ in children]
+
+ def CheckChildSocket(self, state, statename, parent_first):
+ for version in [4, 6]:
+ self.IncomingConnection(version, state, self.netid)
+
+ d = self.sock_diag.FindSockDiagFromFd(self.s)
+ parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
+ children = self.FindChildSockets(self.s)
+ self.assertEquals(1, len(children))
+
+ is_established = (state == self.NOT_YET_ACCEPTED)
+
+ # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
+ # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
+ # Before 4.4, we can see those sockets in dumps, but we can't fetch
+ # or close them.
+ can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)
+
+ for child in children:
+ if can_close_children:
+ self.sock_diag.GetSockDiag(child) # No errors? Good, child found.
+ else:
+ self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
+
+ def CloseParent(expect_reset):
+ msg = "Closing parent IPv%d %s socket %s child" % (
+ version, statename, "before" if parent_first else "after")
+ self.CheckRstOnClose(self.s, None, expect_reset, msg)
+ self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, parent)
+
+ def CheckChildrenClosed():
+ for child in children:
+ self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
+
+ def CloseChildren():
+ for child in children:
+ msg = "Closing child IPv%d %s socket %s parent" % (
+ version, statename, "after" if parent_first else "before")
+ self.sock_diag.GetSockDiag(child)
+ self.CheckRstOnClose(None, child, is_established, msg)
+ self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
+ CheckChildrenClosed()
+
+ if parent_first:
+ # Closing the parent will close child sockets, which will send a RST,
+ # iff they are already established.
+ CloseParent(is_established)
+ if is_established:
+ CheckChildrenClosed()
+ elif can_close_children:
+ CloseChildren()
+ CheckChildrenClosed()
+ self.s.close()
+ else:
+ if can_close_children:
+ CloseChildren()
+ CloseParent(False)
+ self.s.close()
+
+ @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+ def testChildSockets(self):
+ self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", False)
+ self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", True)
+ self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", False)
+ self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", True)
+
+ def CloseDuringBlockingCall(self, sock, call, expected_errno):
+ thread = SocketExceptionThread(sock, call)
+ thread.start()
+ time.sleep(0.1)
+ self.sock_diag.CloseSocketFromFd(sock)
+ thread.join(1)
+ self.assertFalse(thread.is_alive())
+ self.assertIsNotNone(thread.exception)
+ self.assertTrue(isinstance(thread.exception, IOError),
+ "Expected IOError, got %s" % thread.exception)
+ self.assertEqual(expected_errno, thread.exception.errno)
+ self.assertSocketClosed(sock)
+
+ @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+ def testAcceptInterrupted(self):
+ """Tests that accept() is interrupted by SOCK_DESTROY."""
+ for version in [4, 6]:
+ self.IncomingConnection(version, sock_diag.TCP_LISTEN, self.netid)
+ self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
+ self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
+ self.assertRaisesErrno(EINVAL, self.s.accept)
+
+ @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+ def testReadInterrupted(self):
+ """Tests that read() is interrupted by SOCK_DESTROY."""
+ for version in [4, 6]:
+ self.IncomingConnection(version, sock_diag.TCP_ESTABLISHED, self.netid)
+ self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
+ ECONNABORTED)
+ self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
+
+ @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+ def testConnectInterrupted(self):
+ """Tests that connect() is interrupted by SOCK_DESTROY."""
+ for version in [4, 6]:
+ family = {4: AF_INET, 6: AF_INET6}[version]
+ s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
+ self.SelectInterface(s, self.netid, "mark")
+ remoteaddr = self.GetRemoteAddress(version)
+ s.bind(("", 0))
+ _, sport = s.getsockname()[:2]
+ self.CloseDuringBlockingCall(
+ s, lambda sock: sock.connect((remoteaddr, 53)), ECONNABORTED)
+ desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
+ remoteaddr, sport=sport, seq=None)
+ self.ExpectPacketOn(self.netid, desc, syn)
+ msg = "SOCK_DESTROY of socket in connect, expected no RST"
+ self.ExpectNoPacketsOn(self.netid, msg)
if __name__ == "__main__":