# blob: 4b1d0557425a47731fc7f007c827e8272c64edae [file] [log] [blame]
# Copyright 2015 The Android Open Source Project
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
from errno import * # pylint: disable=wildcard-import
import os
import random
import select
from socket import * # pylint: disable=wildcard-import
import struct
import threading
import time
import unittest
import multinetwork_base
import net_test
import packets
import sock_diag
import tcp_test
def HaveUdpDiag():
  """Checks if the current kernel has config CONFIG_INET_UDP_DIAG enabled.

  This config is required for devices running a 4.9 kernel that ship with P.
  In that case, always assume the config is there and use the tests to check
  that it is enabled as required.

  For all other kernel versions, there is no way to tell whether a dump
  succeeded: if the appropriate handler wasn't found, __inet_diag_dump just
  returns an empty result instead of an error. So, just check to see if a UDP
  dump returns no sockets when we know it should return one. If not, some tests
  will be skipped.

  Returns:
    True if the kernel is 4.9 or above, or the CONFIG_INET_UDP_DIAG is enabled.
    False otherwise.
  """
  if net_test.LINUX_VERSION >= (4, 9, 0):
    return True

  # Keep one bound UDP socket open so that a working UDP diag dump has at
  # least one socket to report.
  s = socket(AF_INET6, SOCK_DGRAM, 0)
  try:
    s.bind(("::", 0))
    sd = sock_diag.SockDiag()
    return len(sd.DumpAllInetSockets(IPPROTO_UDP, "")) > 0
  finally:
    s.close()
def HaveSctp():
  """Checks whether SCTP sockets can be created on this kernel.

  sctp_diag was only added in Linux 4.7, so older kernels are reported as not
  supporting SCTP without probing.

  Returns:
    True if an SCTP socket could be created, False otherwise.
  """
  if net_test.LINUX_VERSION < (4, 7, 0):
    return False
  # 132 is the IANA protocol number for SCTP (IPPROTO_SCTP). The literal is
  # used because older Python socket modules may not export the constant.
  try:
    s = socket(AF_INET, SOCK_STREAM, 132)
  except IOError:
    return False
  s.close()
  return True
# Number of socket pairs created by SockDiagBaseTest._CreateLotsOfSockets.
NUM_SOCKETS = 30
# Empty bytecode: dump requests that apply no filter.
NO_BYTECODE = ""
# Kernels on which getsockopt(SO_COOKIE) is expected to work.
# NOTE(review): threshold assumed to match the 4.9 assumption in HaveUdpDiag;
# confirm against the supported kernel branches.
HAVE_SO_COOKIE_SUPPORT = net_test.LINUX_VERSION >= (4, 9, 0)
# IANA protocol number for SCTP; older Python socket modules may not export it.
IPPROTO_SCTP = 132

HAVE_UDP_DIAG = HaveUdpDiag()
HAVE_SCTP = HaveSctp()
class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Basic tests for SOCK_DIAG functionality.

  Relevant kernel commits:
    ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields
    99ee451 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
    3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields
    f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
    e603010 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
    525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
  """

  @staticmethod
  def _CreateLotsOfSockets(socktype):
    """Creates NUM_SOCKETS connected socket pairs of random address families.

    Args:
      socktype: The socket type to create, e.g., SOCK_STREAM or SOCK_DGRAM.

    Returns:
      A dict mapping (addr, sport, dport) tuples to socketpairs, where sport
      and dport are the local ports of the first and second socket.
    """
    socketpairs = {}
    for _ in xrange(NUM_SOCKETS):
      family, addr = random.choice([
          (AF_INET, "127.0.0.1"),
          (AF_INET6, "::1"),
          (AF_INET6, "::ffff:127.0.0.1")])
      socketpair = net_test.CreateSocketPair(family, socktype, addr)
      sport, dport = (socketpair[0].getsockname()[1],
                      socketpair[1].getsockname()[1])
      socketpairs[(addr, sport, dport)] = socketpair
    return socketpairs

  def assertSocketClosed(self, sock):
    """Asserts that sock has no peer any more (getpeername gives ENOTCONN)."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock still has a peer."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def assertSocketsClosed(self, socketpair):
    """Asserts that every socket in socketpair is closed."""
    for sock in socketpair:
      self.assertSocketClosed(sock)

  def assertMarkIs(self, mark, attrs):
    """Asserts that the INET_DIAG_MARK attribute in attrs equals mark."""
    self.assertEqual(mark, attrs.get("INET_DIAG_MARK", None))

  def assertSockInfoMatchesSocket(self, s, info):
    """Asserts that a (diag_msg, attrs) tuple describes socket s.

    Compares family, source address/port, destination address/port (or the
    absence of a peer) and the socket mark.
    """
    diag_msg, attrs = info
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    self.assertEqual(diag_msg.family, family)

    src, sport = s.getsockname()[0:2]
    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    self.assertEqual(diag_msg.id.sport, sport)

    # A wildcard destination means the socket has no peer.
    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
      dst, dport = s.getpeername()[0:2]
      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
      self.assertEqual(diag_msg.id.dport, dport)
    else:
      self.assertRaisesErrno(ENOTCONN, s.getpeername)

    mark = s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
    self.assertMarkIs(mark, attrs)

  def PackAndCheckBytecode(self, instructions):
    """Packs instructions and checks that they decode back cleanly.

    Returns:
      The packed bytecode.
    """
    bytecode = self.sock_diag.PackBytecode(instructions)
    decoded = self.sock_diag.DecodeBytecode(bytecode)
    self.assertEquals(len(instructions), len(decoded))
    self.assertFalse("???" in decoded)
    return bytecode

  def _EventDuringBlockingCall(self, sock, call, expected_errno, event):
    """Simulates an external event during a blocking call on sock.

    Args:
      sock: The socket to use.
      call: A function, the call to make. Takes one parameter, sock.
      expected_errno: The value that call is expected to fail with, or None if
        call is expected to succeed.
      event: A function, the event that will happen during the blocking call.
        Takes one parameter, sock.
    """
    thread = SocketExceptionThread(sock, call)
    thread.start()
    # Give the thread time to enter the blocking call before the event fires.
    time.sleep(0.1)
    event(sock)
    thread.join(1)
    self.assertFalse(thread.is_alive())
    if expected_errno is not None:
      self.assertIsNotNone(thread.exception)
      self.assertTrue(isinstance(thread.exception, IOError),
                      "Expected IOError, got %s" % thread.exception)
      self.assertEqual(expected_errno, thread.exception.errno)
    else:
      self.assertIsNone(thread.exception)
    self.assertSocketClosed(sock)

  def CloseDuringBlockingCall(self, sock, call, expected_errno):
    """Destroys sock with SOCK_DESTROY while call(sock) is blocked."""
    self._EventDuringBlockingCall(
        sock, call, expected_errno,
        lambda sock: self.sock_diag.CloseSocketFromFd(sock))

  def setUp(self):
    super(SockDiagBaseTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    self.socketpairs = {}

  def tearDown(self):
    # Close every socket pair a test stored in self.socketpairs.
    for socketpair in self.socketpairs.values():
      for s in socketpair:
        s.close()
    super(SockDiagBaseTest, self).tearDown()
class SockDiagTest(SockDiagBaseTest):
  """Tests for SOCK_DIAG dumps, lookups and bytecode filters."""

  def testFindsMappedSockets(self):
    """Tests that inet_diag_find_one_icsk can find mapped sockets."""
    socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                           "::ffff:127.0.0.1")
    for sock in socketpair:
      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
      self.sock_diag.GetSockInfo(diag_req)
      # No errors? Good.
    for sock in socketpair:
      sock.close()

  def testFindsAllMySockets(self):
    """Tests that basic socket dumping works."""
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
    self.assertGreaterEqual(len(sockets), NUM_SOCKETS)

    # Find the cookies for all of our sockets. The second socket of each pair
    # appears with sport and dport swapped relative to the dict key.
    cookies = {}
    for diag_msg, unused_attrs in sockets:
      addr = self.sock_diag.GetSourceAddress(diag_msg)
      sport = diag_msg.id.sport
      dport = diag_msg.id.dport
      if ((addr, sport, dport) in self.socketpairs or
          (addr, dport, sport) in self.socketpairs):
        cookies[(addr, sport, dport)] = diag_msg.id.cookie

    # Did we find all the cookies?
    self.assertEquals(2 * NUM_SOCKETS, len(cookies))

    for socketpair in self.socketpairs.values():
      for sock in socketpair:
        # Check that we can find a diag_msg by scanning a dump.
        self.assertSockInfoMatchesSocket(
            sock, self.sock_diag.FindSockInfoFromFd(sock))
        cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie

        # Check that we can find a diag_msg once we know the cookie.
        req = self.sock_diag.DiagReqFromSocket(sock)
        req.id.cookie = cookie
        info = self.sock_diag.GetSockInfo(req)
        self.assertSockInfoMatchesSocket(sock, info)

  def testBytecodeCompilation(self):
    # pylint: disable=bad-whitespace
    instructions = [
        (sock_diag.INET_DIAG_BC_S_GE,   1, 8, 0),                      # 0
        (sock_diag.INET_DIAG_BC_D_LE,   1, 7, 0xffff),                 # 8
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)),       # 16
        (sock_diag.INET_DIAG_BC_JMP,    1, 3, None),                   # 44
        (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)),  # 48
        (sock_diag.INET_DIAG_BC_D_LE,   1, 3, 0x6665),  # not used     # 64
        (sock_diag.INET_DIAG_BC_NOP,    1, 1, None),                   # 72
                                                                       # 76 acc
                                                                       # 80 rej
    ]
    # pylint: enable=bad-whitespace
    bytecode = self.PackAndCheckBytecode(instructions)
    # One line per instruction: the op header (code, yes, little-endian no),
    # followed by its port or host-condition argument.
    expected = (
        "0208500000000000"
        "050848000000ffff"
        "071c20000a800000ffffffff00000000000000000000000000000001"
        "01041c00"
        "0718200002200000ffffffff7f000001"
        "0508100000006566"
        "00040400"
    )
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertMultiLineEqual(expected, bytecode.encode("hex"))
    self.assertEquals(76, len(bytecode))
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode,
                                                        states=states)
    allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
                                                   states=states)
    self.assertItemsEqual(allsockets, filteredsockets)

    # Pick a few sockets in hash table order, and check that the bytecode we
    # compiled selects them properly.
    for socketpair in list(self.socketpairs.values())[:20]:
      for s in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)
        instructions = [
            (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
            (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
        ]
        bytecode = self.PackAndCheckBytecode(instructions)
        self.assertEquals(32, len(bytecode))
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
        self.assertEquals(1, len(sockets))

        # TODO: why doesn't comparing the cstructs work?
        self.assertEquals(diag_msg.Pack(), sockets[0][0].Pack())

  def testCrossFamilyBytecode(self):
    """Checks for a cross-family bug in inet_diag_hostcond matching.

    Relevant kernel commits:
      f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
    """
    # TODO: this is only here because the test fails if there are any open
    # sockets other than the ones it creates itself. Make the bytecode more
    # specific and remove it.
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, "",
                                                       states=states))

    pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1")
    pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1")

    bytecode4 = self.PackAndCheckBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
    bytecode6 = self.PackAndCheckBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])

    # IPv4/v6 filters must never match IPv6/IPv4 sockets...
    # Both filters must match something, or the all() checks pass vacuously.
    v4socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4,
                                                states=states)
    self.assertTrue(v4socks)
    self.assertTrue(all(d.family == AF_INET for d, _ in v4socks))

    v6socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6,
                                                states=states)
    self.assertTrue(v6socks)
    self.assertTrue(all(d.family == AF_INET6 for d, _ in v6socks))

    # Except for mapped addresses, which match both IPv4 and IPv6.
    pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                      "::ffff:127.0.0.1")
    diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
    v4socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                               bytecode4,
                                                               states=states)]
    v6socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                               bytecode6,
                                                               states=states)]
    self.assertTrue(all(d in v4socks for d in diag_msgs))
    self.assertTrue(all(d in v6socks for d in diag_msgs))

    for pair in (pair4, pair6, pair5):
      for s in pair:
        s.close()

  def testPortComparisonValidation(self):
    """Checks for a bug in validating port comparison bytecode.

    Relevant kernel commits:
      5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads
    """
    # A port comparison op whose port argument is missing must be rejected.
    bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8))
    self.assertRaisesErrno(
        EINVAL,
        self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())

  def testNonSockDiagCommand(self):
    """Checks that unknown sock_diag message types fail with EINVAL."""
    def DiagDump(code):
      sock_id = self.sock_diag._EmptyInetDiagSockId()
      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
                                     sock_id))
      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "")

    op = sock_diag.SOCK_DIAG_BY_FAMILY
    DiagDump(op)  # No errors? Good.
    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)

  def CheckSocketCookie(self, inet, addr):
    """Tests that getsockopt SO_COOKIE can get cookie for all sockets."""
    socketpair = net_test.CreateSocketPair(inet, SOCK_STREAM, addr)
    for sock in socketpair:
      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
      cookie = sock.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
      self.assertEqual(diag_msg.id.cookie, cookie)
    for sock in socketpair:
      sock.close()

  @unittest.skipUnless(HAVE_SO_COOKIE_SUPPORT, "SO_COOKIE not supported")
  def testGetsockoptcookie(self):
    self.CheckSocketCookie(AF_INET, "127.0.0.1")
    self.CheckSocketCookie(AF_INET6, "::1")
class SockDestroyTest(SockDiagBaseTest):
  """Tests that SOCK_DESTROY works correctly.

  Relevant kernel commits:
    b613f56 net: diag: split inet_diag_dump_one_icsk into two
    64be0ae net: diag: Add the ability to destroy a socket.
    6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets.
    c1e64e2 net: diag: Support destroying TCP sockets.
    2010b93 net: tcp: deal with listen sockets properly in tcp_abort.

    d48ec88 net: diag: split inet_diag_dump_one_icsk into two
    2438189 net: diag: Add the ability to destroy a socket.
    7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets.
    44047b2 net: diag: Support destroying TCP sockets.
    200dae7 net: tcp: deal with listen sockets properly in tcp_abort.

    9eaff90 net: diag: split inet_diag_dump_one_icsk into two
    d60326c net: diag: Add the ability to destroy a socket.
    3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets.
    529dfc6 net: diag: Support destroying TCP sockets.
    9c712fe net: tcp: deal with listen sockets properly in tcp_abort.

    100263d net: diag: split inet_diag_dump_one_icsk into two
    194c5f3 net: diag: Add the ability to destroy a socket.
    8387ea2 net: diag: Support SOCK_DESTROY for inet sockets.
    b80585a net: diag: Support destroying TCP sockets.
    476c6ce net: tcp: deal with listen sockets properly in tcp_abort.

    56eebf8 net: diag: split inet_diag_dump_one_icsk into two
    fb486c9 net: diag: Add the ability to destroy a socket.
    0c02b7e net: diag: Support SOCK_DESTROY for inet sockets.
    67c71d8 net: diag: Support destroying TCP sockets.
    a76e0ec net: tcp: deal with listen sockets properly in tcp_abort.
    e6e277b net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    76c83a9 net: diag: split inet_diag_dump_one_icsk into two
    f7cf791 net: diag: Add the ability to destroy a socket.
    1c42248 net: diag: Support SOCK_DESTROY for inet sockets.
    c9e8440d net: diag: Support destroying TCP sockets.
    3d9502c tcp: diag: add support for request sockets to tcp_abort()
    001cf75 net: tcp: deal with listen sockets properly in tcp_abort.
  """

  def testClosesSockets(self):
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    for _, socketpair in self.socketpairs.iteritems():
      # Close one of the sockets.
      # This will send a RST that will close the other side as well.
      s = random.choice(socketpair)
      if random.randrange(0, 2) == 1:
        self.sock_diag.CloseSocketFromFd(s)
      else:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)

        # Get the cookie wrong and ensure that we get an error and the socket
        # is not closed.
        real_cookie = diag_msg.id.cookie
        diag_msg.id.cookie = os.urandom(len(real_cookie))
        req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
        self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
        self.assertSocketConnected(s)

        # Now close it with the correct cookie.
        req.id.cookie = real_cookie
        self.sock_diag.CloseSocket(req)

      # Check that both sockets in the pair are closed.
      self.assertSocketsClosed(socketpair)

  # TODO: Test that killing unix sockets returns EOPNOTSUPP.
class SocketExceptionThread(threading.Thread):
  """Runs operation(sock) in a daemon thread and records any failure.

  An IOError or AssertionError raised by the operation is stored in
  self.exception. It stays None if the operation returned without raising
  one of those. The caller inspects it after join().
  """

  def __init__(self, sock, operation):
    self.exception = None
    super(SocketExceptionThread, self).__init__()
    self.daemon = True
    self.sock = sock
    self.operation = operation

  def run(self):
    try:
      self.operation(self.sock)
    except (IOError, AssertionError) as e:
      self.exception = e
class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """SOCK_DIAG tests that need incoming TCP connections."""

  def testIpv4MappedSynRecvSocket(self):
    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.

    Relevant kernel commits:
      457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
    """
    netid = random.choice(self.tuns.keys())
    self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid)
    sock_id = self.sock_diag._EmptyInetDiagSockId()
    sock_id.sport = self.port
    states = 1 << tcp_test.TCP_SYN_RECV
    req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
    children = self.sock_diag.Dump(req, NO_BYTECODE)

    # Without at least one child the loop below would pass vacuously.
    self.assertTrue(children)
    for child, unused_args in children:
      self.assertEqual(tcp_test.TCP_SYN_RECV, child.state)
class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Tests SOCK_DESTROY on TCP sockets in various states."""

  def setUp(self):
    super(SockDestroyTcpTest, self).setUp()
    self.netid = random.choice(self.tuns.keys())

  def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
    """Closes the socket and checks whether a RST is sent or not.

    Args:
      sock: The socket to destroy by fd, or None if req is given.
      req: A diag request identifying the socket to destroy, or None if sock
        is given.
      expect_reset: Whether a RST is expected on self.netid.
      msg: Description used in failure messages.
      do_close: Whether to close() sock afterwards (only used with sock).
    """
    if sock is not None:
      self.assertIsNone(req, "Must specify sock or req, not both")
      self.sock_diag.CloseSocketFromFd(sock)
      self.assertRaisesErrno(EINVAL, sock.accept)
    else:
      self.assertIsNone(sock, "Must specify sock or req, not both")
      self.sock_diag.CloseSocket(req)

    if expect_reset:
      desc, rst = self.RstPacket()
      msg = "%s: expecting %s: " % (msg, desc)
      self.ExpectPacketOn(self.netid, msg, rst)
    else:
      msg = "%s: " % msg
      self.ExpectNoPacketsOn(self.netid, msg)

    if sock is not None and do_close:
      sock.close()

  def CheckTcpReset(self, state, statename):
    """Destroys listening and accepted sockets in state and checks for RSTs."""
    for version in [4, 5, 6]:
      msg = "Closing incoming IPv%d %s socket" % (version, statename)
      self.IncomingConnection(version, state, self.netid)
      self.CheckRstOnClose(self.s, None, False, msg)
      if state != tcp_test.TCP_LISTEN:
        msg = "Closing accepted IPv%d %s socket" % (version, statename)
        self.CheckRstOnClose(self.accepted, None, True, msg)

  def testTcpResets(self):
    """Checks that closing sockets in appropriate states sends a RST."""
    self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN")
    self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED")
    self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")

  def testFinWait1Socket(self):
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)

      # Get the cookie so we can find this socket after we close it.
      diag_msg = self.sock_diag.FindSockDiagFromFd(self.accepted)
      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)

      # Close the socket and check that it goes into FIN_WAIT1 and sends a FIN.
      self.accepted.close()
      diag_req.states = 1 << tcp_test.TCP_FIN_WAIT1
      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
      self.assertEquals(tcp_test.TCP_FIN_WAIT1, diag_msg.state)
      desc, fin = self.FinPacket()
      self.ExpectPacketOn(self.netid, "Closing FIN_WAIT1 socket", fin)

      # Destroy the socket and expect no RST.
      self.CheckRstOnClose(None, diag_req, False, "Closing FIN_WAIT1 socket")
      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)

      # The socket is still there in FIN_WAIT1: SOCK_DESTROY did nothing
      # because userspace had already closed it.
      self.assertEquals(tcp_test.TCP_FIN_WAIT1, diag_msg.state)

      # ACK the FIN so we don't trip over retransmits in future tests.
      finversion = 4 if version == 5 else version
      desc, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin)
      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
      self.ReceivePacketOn(self.netid, finack)

      # See if we can find the resulting FIN_WAIT2 socket. This does not appear
      # to work on 3.10.
      if net_test.LINUX_VERSION >= (3, 18):
        diag_req.states = 1 << tcp_test.TCP_FIN_WAIT2
        infos = self.sock_diag.Dump(diag_req, "")
        self.assertTrue(any(diag_msg.state == tcp_test.TCP_FIN_WAIT2
                            for diag_msg, attrs in infos),
                        "Expected to find FIN_WAIT2 socket in %s" % infos)

  def FindChildSockets(self, s):
    """Finds the SYN_RECV child sockets of a given listening socket.

    Returns:
      A list of diag requests, one per child whose mark matches self.netid.
    """
    d = self.sock_diag.FindSockDiagFromFd(s)
    req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED
    # Match by address/port only, not by the listener's cookie.
    req.id.cookie = "\x00" * 8

    # Sanity check: a mark filter that matches nothing returns nothing.
    bad_bytecode = self.PackAndCheckBytecode(
        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (0xffff, 0xffff))])
    self.assertEqual([], self.sock_diag.Dump(req, bad_bytecode))

    bytecode = self.PackAndCheckBytecode(
        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (self.netid, 0xffff))])
    children = self.sock_diag.Dump(req, bytecode)
    return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
            for d, _ in children]

  def CheckChildSocket(self, version, statename, parent_first):
    """Destroys a listener and its child, in either order, checking RSTs."""
    state = getattr(tcp_test, statename)

    self.IncomingConnection(version, state, self.netid)

    d = self.sock_diag.FindSockDiagFromFd(self.s)
    parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    children = self.FindChildSockets(self.s)
    self.assertEquals(1, len(children))

    is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED)
    expected_state = tcp_test.TCP_ESTABLISHED if is_established else state

    # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
    # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
    # Before 4.4, we can see those sockets in dumps, but we can't fetch
    # or close them.
    can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)

    for child in children:
      if can_close_children:
        diag_msg, attrs = self.sock_diag.GetSockInfo(child)
        self.assertEquals(diag_msg.state, expected_state)
        self.assertMarkIs(self.netid, attrs)
      else:
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)

    def CloseParent(expect_reset):
      msg = "Closing parent IPv%d %s socket %s child" % (
          version, statename, "before" if parent_first else "after")
      self.CheckRstOnClose(self.s, None, expect_reset, msg)
      self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, parent)

    def CheckChildrenClosed():
      for child in children:
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)

    def CloseChildren():
      for child in children:
        msg = "Closing child IPv%d %s socket %s parent" % (
            version, statename, "after" if parent_first else "before")
        self.CheckRstOnClose(None, child, is_established, msg)
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)

    if parent_first:
      # Closing the parent will close child sockets, which will send a RST,
      # iff they are already established.
      CloseParent(is_established)
      if is_established:
        CheckChildrenClosed()
      elif can_close_children:
        CloseChildren()
        CheckChildrenClosed()
    else:
      if can_close_children:
        CloseChildren()
      CloseParent(False)

  def testChildSockets(self):
    for version in [4, 5, 6]:
      self.CheckChildSocket(version, "TCP_SYN_RECV", False)
      self.CheckChildSocket(version, "TCP_SYN_RECV", True)
      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False)
      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True)

  def testAcceptInterrupted(self):
    """Tests that accept() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid)
      self.assertRaisesErrno(ENOTCONN, self.s.recv, 4096)
      self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
      self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
      self.assertRaisesErrno(EINVAL, self.s.accept)
      # TODO: this should really return an error such as ENOTCONN...
      self.assertEquals("", self.s.recv(4096))

  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
      self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
                                   ECONNABORTED)
      # Writing returns EPIPE, and reading returns EOF.
      self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
      self.assertEquals("", self.accepted.recv(4096))
      self.assertEquals("", self.accepted.recv(4096))

  def testConnectInterrupted(self):
    """Tests that connect() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
      s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
      self.SelectInterface(s, self.netid, "mark")

      remotesockaddr = self.GetRemoteSocketAddress(version)
      remoteaddr = self.GetRemoteAddress(version)
      s.bind(("", 0))
      _, sport = s.getsockname()[:2]
      self.CloseDuringBlockingCall(
          s, lambda sock: sock.connect((remotesockaddr, 53)), ECONNABORTED)
      desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
                              remoteaddr, sport=sport, seq=None)
      self.ExpectPacketOn(self.netid, desc, syn)
      msg = "SOCK_DESTROY of socket in connect, expected no RST"
      self.ExpectNoPacketsOn(self.netid, msg)
      s.close()
class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Tests that the effect of SOCK_DESTROY on poll matches TCP RSTs.

  The behaviour of poll() in these cases is not what we might expect: if only
  POLLIN is specified, it will return POLLIN|POLLERR|POLLHUP, but if POLLOUT
  is (also) specified, it will only return POLLOUT.
  """

  POLLIN_OUT = select.POLLIN | select.POLLOUT
  POLLIN_ERR_HUP = select.POLLIN | select.POLLERR | select.POLLHUP

  def setUp(self):
    super(PollOnCloseTest, self).setUp()
    self.netid = random.choice(self.tuns.keys())

  POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"),
                (select.POLLERR, "ERR"), (select.POLLHUP, "HUP")]

  def PollResultToString(self, poll_events, ignoremask):
    """Renders (fd, events) pairs as (fd, "IN|OUT|...") ignoring ignoremask."""
    out = []
    for fd, event in poll_events:
      flags = [name for (flag, name) in self.POLL_FLAGS
               if event & flag & ~ignoremask != 0]
      out.append((fd, "|".join(flags)))
    return out

  def BlockingPoll(self, sock, mask, expected, ignoremask):
    """Polls sock for mask and asserts the returned events equal expected."""
    p = select.poll()
    p.register(sock, mask)
    expected_fds = [(sock.fileno(), expected)]
    # Don't block forever or we'll hang continuous test runs on failure.
    # A 5-second timeout should be long enough not to be flaky.
    actual_fds = p.poll(5000)
    self.assertEqual(self.PollResultToString(expected_fds, ignoremask),
                     self.PollResultToString(actual_fds, ignoremask))

  def RstDuringBlockingCall(self, sock, call, expected_errno):
    """Receives a TCP RST on self.netid while call(sock) is blocked."""
    self._EventDuringBlockingCall(
        sock, call, expected_errno,
        lambda _: self.ReceiveRstPacketOn(self.netid))

  def assertSocketErrors(self, errno):
    # The first operation returns the expected errno.
    self.assertRaisesErrno(errno, self.accepted.recv, 4096)

    # Subsequent operations behave as normal.
    self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
    self.assertEquals("", self.accepted.recv(4096))
    self.assertEquals("", self.accepted.recv(4096))

  def CheckPollDestroy(self, mask, expected, ignoremask):
    """Interrupts a poll() with SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
      self.CloseDuringBlockingCall(
          self.accepted,
          lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask),
          None)
      self.assertSocketErrors(ECONNABORTED)

  def CheckPollRst(self, mask, expected, ignoremask):
    """Interrupts a poll() by receiving a TCP RST."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
      self.RstDuringBlockingCall(
          self.accepted,
          lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask),
          None)
      self.assertSocketErrors(ECONNRESET)

  def testReadPollRst(self):
    # Until 3d4762639d ("tcp: remove poll() flakes when receiving RST"), poll()
    # would sometimes return POLLERR and sometimes POLLIN|POLLERR|POLLHUP. This
    # is due to a race inside the kernel and thus is not visible on the VM, only
    # on physical hardware.
    if net_test.LINUX_VERSION < (4, 14, 0):
      ignoremask = select.POLLIN | select.POLLHUP
    else:
      ignoremask = 0
    self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask)

  def testWritePollRst(self):
    self.CheckPollRst(select.POLLOUT, select.POLLOUT, 0)

  def testReadWritePollRst(self):
    self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT, 0)

  def testReadPollDestroy(self):
    # tcp_abort has the same race that tcp_reset has, but it's not fixed yet.
    ignoremask = select.POLLIN | select.POLLHUP
    self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask)

  def testWritePollDestroy(self):
    self.CheckPollDestroy(select.POLLOUT, select.POLLOUT, 0)

  def testReadWritePollDestroy(self):
    self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT, 0)
@unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
class SockDestroyUdpTest(SockDiagBaseTest):
  """Tests SOCK_DESTROY on UDP sockets.

  Relevant kernel commits:
    upstream net-next:
      5d77dca net: diag: support SOCK_DESTROY for UDP sockets
      f95bf34 net: diag: make udp_diag_destroy work for mapped addresses.
  """

  def testClosesUdpSockets(self):
    self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM)
    for _, socketpair in self.socketpairs.iteritems():
      s1, s2 = socketpair

      self.assertSocketConnected(s1)
      self.sock_diag.CloseSocketFromFd(s1)
      self.assertSocketClosed(s1)

      self.assertSocketConnected(s2)
      self.sock_diag.CloseSocketFromFd(s2)
      self.assertSocketClosed(s2)

  def BindToRandomPort(self, s, addr):
    """Binds s to a random port in [1024, 65535) on addr.

    Retries on EADDRINUSE; any other socket error is re-raised.

    Returns:
      The port that s was bound to.

    Raises:
      ValueError: if no free port was found after 20 attempts.
    """
    attempts = 20
    for _ in xrange(attempts):
      port = random.randrange(1024, 65535)
      try:
        s.bind((addr, port))
        return port
      except error as e:
        if e.errno != EADDRINUSE:
          raise
    raise ValueError("Could not find a free port on %s after %d attempts" %
                     (addr, attempts))

  def testSocketAddressesAfterClose(self):
    for version in 4, 5, 6:
      netid = random.choice(self.NETIDS)
      dst = self.GetRemoteSocketAddress(version)
      unspec = {4: "0.0.0.0", 5: "::", 6: "::"}[version]

      # Closing a socket that was not explicitly bound (i.e., bound via
      # connect(), not bind()) clears the source address and port.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      self.SelectInterface(s, netid, "mark")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, 0), s.getsockname()[:2])
      s.close()

      # Closing a socket bound to an IP address leaves the address as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MySocketAddress(version, netid)
      s.bind((src, 0))
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, 0), s.getsockname()[:2])
      s.close()

      # Closing a socket bound to a port leaves the port as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      port = self.BindToRandomPort(s, "")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, port), s.getsockname()[:2])
      s.close()

      # Closing a socket bound to IP address and port leaves both as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MySocketAddress(version, netid)
      port = self.BindToRandomPort(s, src)
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, port), s.getsockname()[:2])
      s.close()

  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
      s = net_test.UDPSocket(family)
      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
      addr = self.GetRemoteAddress(version)

      # Check that reads on connected sockets are interrupted.
      s.connect((addr, 53))
      self.assertEquals(3, s.send("foo"))
      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                   ECONNABORTED)

      # A destroyed socket is no longer connected, but still usable.
      self.assertRaisesErrno(EDESTADDRREQ, s.send, "foo")
      self.assertEquals(3, s.sendto("foo", (addr, 53)))

      # Check that reads on unconnected sockets are also interrupted.
      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                   ECONNABORTED)
      s.close()
class SockDestroyPermissionTest(SockDiagBaseTest):
  """Tests that SOCK_DESTROY fails with EPERM for an unprivileged UID."""

  def CheckPermissions(self, socktype):
    """Checks that only a privileged caller can destroy a socket of socktype."""
    s = socket(AF_INET6, socktype, 0)
    self.SelectInterface(s, random.choice(self.NETIDS), "mark")
    if socktype == SOCK_STREAM:
      s.listen(1)
      expectedstate = tcp_test.TCP_LISTEN
    else:
      s.connect((self.GetRemoteAddress(6), 53))
      expectedstate = tcp_test.TCP_ESTABLISHED

    # Make sure we are about to destroy the socket we think we are.
    diag_msg = self.sock_diag.FindSockDiagFromFd(s)
    self.assertEqual(expectedstate, diag_msg.state)

    with net_test.RunAsUid(12345):
      self.assertRaisesErrno(
          EPERM, self.sock_diag.CloseSocketFromFd, s)

    self.sock_diag.CloseSocketFromFd(s)
    # Once destroyed, the socket can no longer be found by fd.
    self.assertRaises(ValueError, self.sock_diag.CloseSocketFromFd, s)
    s.close()

  @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
  def testUdp(self):
    self.CheckPermissions(SOCK_DGRAM)

  def testTcp(self):
    self.CheckPermissions(SOCK_STREAM)
class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Tests SOCK_DIAG bytecode filters that use marks.

  Relevant kernel commits:
    upstream net-next:
      627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks.
      a52e95a net: diag: allow socket bytecode filters to match socket marks
      d545cac net: inet: diag: expose the socket mark to privileged processes.
  """

  def FilterEstablishedSockets(self, mark, mask):
    """Dumps ESTABLISHED TCP sockets whose mark matches (mark, mask)."""
    instructions = [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (mark, mask))]
    bytecode = self.sock_diag.PackBytecode(instructions)
    return self.sock_diag.DumpAllInetSockets(
        IPPROTO_TCP, bytecode, states=(1 << tcp_test.TCP_ESTABLISHED))

  def assertSamePorts(self, ports, diag_msgs):
    """Asserts that the source ports in diag_msgs equal ports (any order)."""
    expected = sorted(ports)
    actual = sorted([msg[0].id.sport for msg in diag_msgs])
    self.assertEquals(expected, actual)

  def SockInfoMatchesSocket(self, s, info):
    """Returns whether info describes socket s, without failing the test."""
    try:
      self.assertSockInfoMatchesSocket(s, info)
      return True
    except AssertionError:
      return False

  @staticmethod
  def SocketDescription(s):
    return "%s -> %s" % (str(s.getsockname()), str(s.getpeername()))

  def assertFoundSockets(self, infos, sockets):
    """Asserts that infos contains exactly one entry for each of sockets."""
    matches = {}
    for s in sockets:
      match = None
      for info in infos:
        if self.SockInfoMatchesSocket(s, info):
          if match:
            self.fail("Socket %s matched both %s and %s" %
                      (self.SocketDescription(s), match, info))
          match = info
          matches[s] = info
      self.assertTrue(s in matches, "Did not find socket %s in dump" %
                      self.SocketDescription(s))

    matched = list(matches.values())
    for i in infos:
      if i not in matched:
        self.fail("Too many sockets in dump, first unexpected: %s" % str(i))

  def testMarkBytecode(self):
    family, addr = random.choice([
        (AF_INET, "127.0.0.1"),
        (AF_INET6, "::1"),
        (AF_INET6, "::ffff:127.0.0.1")])
    s1, s2 = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
    s1.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xfff1234)
    s2.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xf0f1235)

    infos = self.FilterEstablishedSockets(0x1234, 0xffff)
    self.assertFoundSockets(infos, [s1])

    infos = self.FilterEstablishedSockets(0x1234, 0xfffe)
    self.assertFoundSockets(infos, [s1, s2])

    infos = self.FilterEstablishedSockets(0x1235, 0xffff)
    self.assertFoundSockets(infos, [s2])

    infos = self.FilterEstablishedSockets(0x0, 0x0)
    self.assertFoundSockets(infos, [s1, s2])

    infos = self.FilterEstablishedSockets(0xfff0000, 0xf0fed00)
    self.assertEquals(0, len(infos))

    # Mark filters are only available to privileged processes.
    with net_test.RunAsUid(12345):
      self.assertRaisesErrno(EPERM, self.FilterEstablishedSockets,
                             0xfff0000, 0xf0fed00)

    s1.close()
    s2.close()

  @staticmethod
  def SetRandomMark(s):
    """Sets a random SO_MARK on s and returns it."""
    # Python doesn't like marks that don't fit into a signed int.
    mark = random.randrange(0, 2**31 - 1)
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, mark)
    return mark

  def assertSocketMarkIs(self, s, mark):
    """Asserts s reports mark, and that unprivileged dumps hide the mark."""
    diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s)
    self.assertMarkIs(mark, attrs)
    with net_test.RunAsUid(12345):
      diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s)
      self.assertMarkIs(None, attrs)

  def testMarkInAttributes(self):
    testcases = [(AF_INET, "127.0.0.1"),
                 (AF_INET6, "::1"),
                 (AF_INET6, "::ffff:127.0.0.1")]
    for family, addr in testcases:
      # TCP listen sockets.
      server = socket(family, SOCK_STREAM, 0)
      server.bind((addr, 0))
      port = server.getsockname()[1]
      server.listen(1)  # Or the socket won't be in the hashtables.
      server_mark = self.SetRandomMark(server)
      self.assertSocketMarkIs(server, server_mark)

      # TCP client sockets.
      client = socket(family, SOCK_STREAM, 0)
      client_mark = self.SetRandomMark(client)
      client.connect((addr, port))
      self.assertSocketMarkIs(client, client_mark)

      # TCP server sockets.
      accepted, _ = server.accept()
      self.assertSocketMarkIs(accepted, server_mark)

      accepted_mark = self.SetRandomMark(accepted)
      self.assertSocketMarkIs(accepted, accepted_mark)
      self.assertSocketMarkIs(server, server_mark)

      accepted.close()
      client.close()
      server.close()

      # Other TCP states are tested in SockDestroyTcpTest.

      # UDP sockets.
      s = socket(family, SOCK_DGRAM, 0)
      mark = self.SetRandomMark(s)
      s.connect(("", 53))
      self.assertSocketMarkIs(s, mark)
      s.close()

      # Basic test for SCTP. sctp_diag was only added in 4.7.
      if HAVE_SCTP:
        s = socket(family, SOCK_STREAM, IPPROTO_SCTP)
        s.bind((addr, 0))
        mark = self.SetRandomMark(s)
        self.assertSocketMarkIs(s, mark)
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_SCTP, NO_BYTECODE)
        self.assertEqual(1, len(sockets))
        self.assertEqual(mark, sockets[0][1].get("INET_DIAG_MARK", None))
        s.close()
if __name__ == "__main__":
  # Run all test classes in this module when executed directly.
  unittest.main()