blob: 2d466abb15255e30ba05e8de87e5dee8891027f3 [file] [log] [blame]
# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Connection that supports our pickling scheme"""
import os
import logging
import select
from errno import ESOCKTNOSUPPORT, ENOTCONN
from array import array
import socket as socket_module
from socket import (
AF_UNIX,
CMSG_LEN,
MSG_CTRUNC,
MSG_TRUNC,
SCM_RIGHTS,
SHUT_RD,
SHUT_WR,
SOCK_DGRAM,
SOCK_STREAM,
SOL_SOCKET,
socket,
socketpair,
)
from .util import (
FdHolder,
SafeClosingObject,
cached_property,
die_due_to_fatal_exception,
the,
)
from .reduction import (
MessageTooLargeError,
PickleContext,
UnpickleContext,
fancy_pickle,
fancy_unpickle,
)
log = logging.getLogger(__name__)

# Optional socket features. Each name is 0 when this platform's socket
# module doesn't define it; the rest of the module treats 0 as
# "feature unavailable".
MSG_CMSG_CLOEXEC, SOCK_SEQPACKET, MSG_DONTWAIT, MSG_NOSIGNAL = (
    getattr(socket_module, _feature, 0)
    for _feature in ("MSG_CMSG_CLOEXEC", "SOCK_SEQPACKET",
                     "MSG_DONTWAIT", "MSG_NOSIGNAL"))

# Debug builds only: MODERNMP_SIMULATE_MACOS=1 zeroes these features so
# the fallback code paths can be exercised on a platform that has them.
if __debug__ and os.getenv("MODERNMP_SIMULATE_MACOS") == "1":
    MSG_NOSIGNAL = MSG_DONTWAIT = SOCK_SEQPACKET = 0

# Notes on AF_UNIX SOCK_DGRAM:
#
# Suppose sockets A and B are connected to each other and A sends while
# B is dead. The send can fail with ECONNREFUSED (first send after B
# closes), ENOTCONN (subsequent sends from A), or EPIPE (B shut down
# reading but didn't close). All three mean the same thing: the peer is
# dead.
#
# We always use born-connected socketpairs, so endpoint binding is never
# a concern.

# Maximum size, in bytes, of a datagram message. Messages larger than
# this size get sent indirectly via a file descriptor that we send
# over the datagram.
MAX_MSG_SIZE = 64 * 1024

# Maximum number of file descriptors in a single message.
MAX_MSG_FD = 1024

# Ancillary-buffer size needed to receive MAX_MSG_FD descriptors at
# 4 bytes each (assumes a 4-byte C int, matching array("i")).
MSG_ANC_SIZE = CMSG_LEN(MAX_MSG_FD * 4)

# recvmsg() flags: request close-on-exec FDs where supported; the
# non-blocking variant additionally passes MSG_DONTWAIT.
RECVMSG_BLOCK_FLAGS = MSG_CMSG_CLOEXEC
RECVMSG_NONBLOCK_FLAGS = MSG_DONTWAIT | RECVMSG_BLOCK_FLAGS

# sendmsg() flags: MSG_NOSIGNAL where supported, plus MSG_DONTWAIT for
# the non-blocking variant.
SENDMSG_BLOCK_FLAGS = MSG_NOSIGNAL
SENDMSG_NONBLOCK_FLAGS = MSG_DONTWAIT | SENDMSG_BLOCK_FLAGS

# recvmsg() result flags meaning the payload or ancillary data was cut off.
RECV_FLAGS_TRUNCATED = MSG_CTRUNC | MSG_TRUNC
class PeerDiedException(Exception):
    """Raised when a channel detects that its peer has died

    Channel folds several low-level failures (connection errors, ENOTCONN,
    and an empty read) into this single exception type.
    """
def _extract_fds_from_ancdata(ancdata):
try:
raw_fds = array("i")
for cmsg_level, cmsg_type, cmsg_data in ancdata:
if cmsg_level == SOL_SOCKET and cmsg_type == SCM_RIGHTS:
raw_fds.frombytes(cmsg_data[:len(cmsg_data) -
(len(cmsg_data) % raw_fds.itemsize)])
fds = [FdHolder.steal(raw_fd) for raw_fd in raw_fds]
except:
# If we fail here, we leak the SCM_RIGHTS FDs with no way to
# recover, so just die.
die_due_to_fatal_exception("slurping ancdata")
if not MSG_CMSG_CLOEXEC:
for fd in fds:
os.set_inheritable(fd.fileno(), False)
return fds
def fdhs_to_ancdata(fdhs):
    """Build sendmsg() ancillary data carrying the FDs held by FDHS

    FDHS is a sequence of FdHolder objects (or empty/None). Return an
    empty tuple when there is nothing to send. Otherwise return a
    one-element list holding an SCM_RIGHTS control message. Raise
    MessageTooLargeError if FDHS holds more than MAX_MSG_FD descriptors.
    """
    if fdhs:
        nr_fds = len(fdhs)
        if nr_fds > MAX_MSG_FD:
            raise MessageTooLargeError(
                "too many FDs: want {!r}; max is {!r}".format(nr_fds,
                                                               MAX_MSG_FD))
        fd_array = array("i", [holder.fileno() for holder in fdhs])
        return [(SOL_SOCKET, SCM_RIGHTS, fd_array)]
    return ()
class MessageSendContext(PickleContext):
    """PickleContext that also knows how to commit a message send"""

    def sendmsg(self,
                pickle_data: bytes,
                block: bool,
                channel: "Channel") -> None:
        """Send PICKLE_DATA, plus this context's FDs, through CHANNEL.

        Subclasses may override this hook to run code before and after the
        send. An override should call CHANNEL.sendmsg_raw, or do something
        with an equivalent effect.
        """
        send_raw = channel.sendmsg_raw
        send_raw(pickle_data, fdhs_to_ancdata(self.fdhs), block)
class Channel(SafeClosingObject):
    """General-purpose object-based intramachine socket connection

    Each Channel wraps one AF_UNIX socket (SOCK_SEQPACKET or SOCK_DGRAM)
    and exchanges fancy_pickle'd objects over it. File descriptors are
    passed alongside the data with SCM_RIGHTS.

    This class does no internal locking. You must make sure that at most
    one reader and one writer can be using the Channel across all
    processes.
    """

    @classmethod
    def __make_pair_1(cls):
        """Create a connected pair using the best available socket flavor

        The first constructor that succeeds is stored on CLS under this
        method's (mangled) name, so later calls skip the probing.
        """
        candidates = [ChannelDgram.make_pair_1]
        if (os.getenv("MODERNMP_FORCE_DGRAM") != "1" and
                SOCK_SEQPACKET and
                MSG_DONTWAIT):
            candidates.append(ChannelSeqPacket.make_pair_1)
        while candidates:
            # Try the most preferred (last-appended) candidate first.
            candidate = candidates.pop()
            try:
                channels = candidate()
            except OSError as ex:
                # Socket type unsupported here: fall back to the next one.
                if ex.errno != ESOCKTNOSUPPORT:
                    raise
            else:
                cls.__make_pair_1 = candidate
                return channels
        raise RuntimeError("no way to create sockets")

    @classmethod
    def make_pair(cls, duplex: bool = True):
        """Make a pair of connected channels

        Return a pair (CHANNEL_A, CHANNEL_B).

        If DUPLEX is true, the two channels are indistinguishable and data
        can flow in either direction.

        If DUPLEX is false, data can flow only from CHANNEL_B to CHANNEL_A.
        """
        channels = cls.__make_pair_1()
        if not duplex:
            channels[0].shutdown(SHUT_WR)
            channels[1].shutdown(SHUT_RD)
        return channels

    def __init__(self, sock):
        """Private constructor: use the make_pair() constructor"""
        assert isinstance(sock, socket)
        assert sock.family == AF_UNIX
        assert sock.type in (SOCK_SEQPACKET, SOCK_DGRAM)
        self._socket = sock

    def __getstate__(self):
        """Return the instance dict minus all cached_property values

        Walk the whole MRO, not just the concrete class. Otherwise values
        cached by properties declared on a base class (e.g. _send_poller,
        or fd_numbers for subclasses that don't override it) would be
        pickled. Those are poll objects and FD numbers tied to this
        process.
        """
        state = self.__dict__.copy()
        for klass in type(self).__mro__:
            for name, value in vars(klass).items():
                if isinstance(value, cached_property):
                    state.pop(name, None)
        return state

    @cached_property
    def fd_numbers(self):
        """Tuple of all file descriptors that we care about"""
        return self._socket.fileno(),

    @property
    def data_socket(self):
        """The data socket, for commands"""
        return self._socket

    def sendmsg_raw(self,
                    pickle_data: bytes,
                    ancdata,
                    block: bool) -> None:
        """Send one already-pickled message

        PICKLE_DATA is the message payload. ANCDATA is sendmsg() ancillary
        data: an iterable of (level, type, data) tuples, such as the value
        returned by fdhs_to_ancdata(). BLOCK selects SENDMSG_BLOCK_FLAGS or
        SENDMSG_NONBLOCK_FLAGS.
        """
        if __debug__:
            self.__check_blocking(block)
        flags = SENDMSG_BLOCK_FLAGS if block else SENDMSG_NONBLOCK_FLAGS
        ret = self._socket.sendmsg([pickle_data], ancdata, flags)
        # We expect a message-oriented socket never to send a partial message.
        assert ret == len(pickle_data)

    def poll_for_send(self):
        """Wait until the channel becomes ready for send"""
        return self._send_poller.poll()

    @cached_property
    def _send_poller(self):
        poll = select.poll()
        poll.register(self._socket, select.POLLOUT)
        return poll

    def send(self,
             obj,
             make_message_send_context=MessageSendContext,
             *,
             block=True):
        """Send an object through this connection

        OBJ is the object to send, which we'll send with fancy_pickle.
        If the pickling generates file descriptors, send them along with the
        message using SCM_RIGHTS.

        BLOCK controls whether the send blocks.

        MAKE_MESSAGE_SEND_CONTEXT is a function of no arguments that
        returns a MessageSendContext object, which also becomes the
        pickle context.

        If the target ceases to exist while we're sending the message, raise
        PeerDiedException. (We fold socket errors and resource manager
        errors into this one exception type.) All other exceptions are
        serious internal errors.

        Return true if we sent a message and false if we successfully
        failed to send a message. The latter case occurs only when BLOCK
        is false and sending the message would have blocked.
        """
        pickle_data, pc = fancy_pickle(
            obj,
            size_limit=MAX_MSG_SIZE,
            make_pickle_context=make_message_send_context)
        assert isinstance(pc, MessageSendContext)
        assert len(pickle_data) <= MAX_MSG_SIZE
        try:
            pc.sendmsg(pickle_data, block, self)
        except BlockingIOError:
            assert not block
            return False
        except ConnectionError as ex:
            raise PeerDiedException from ex
        except OSError as ex:
            if ex.errno == ENOTCONN:
                raise PeerDiedException from ex
            raise
        return True

    def __check_blocking(self, block):
        # Debug check: a blocking operation needs a blocking fd. A
        # non-blocking one needs either MSG_DONTWAIT or a non-blocking fd.
        if __debug__:
            fd_is_blocking = os.get_blocking(self._socket.fileno())
            if block:
                assert fd_is_blocking
            else:
                assert MSG_DONTWAIT or not fd_is_blocking

    @staticmethod
    def _close_ancdata_fds(ancdata):
        """Close every SCM_RIGHTS file descriptor in ANCDATA

        Used to dispose of descriptors attached to a message we reject.
        A failure to close an individual descriptor is logged rather than
        raised, so the caller's own error is the one that propagates.
        """
        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != SOL_SOCKET or cmsg_type != SCM_RIGHTS:
                continue
            raw_fds = array("i")
            raw_fds.frombytes(cmsg_data[:len(cmsg_data) -
                                        (len(cmsg_data) % raw_fds.itemsize)])
            for raw_fd in raw_fds:
                try:
                    os.close(raw_fd)
                except OSError:
                    log.warning("failed to close fd %d from rejected message",
                                raw_fd, exc_info=True)

    def _recvmsg(self, data_buf, block):
        """Receive one raw message into DATA_BUF

        Return (NBYTES, ANCDATA). Raise RuntimeError if the payload or the
        ancillary data was truncated.
        """
        if __debug__:
            self.__check_blocking(block)
        flags = RECVMSG_BLOCK_FLAGS if block else RECVMSG_NONBLOCK_FLAGS
        nbytes, ancdata, recv_flags, _address = self._socket.recvmsg_into(
            [data_buf], MSG_ANC_SIZE, flags)
        if recv_flags & RECV_FLAGS_TRUNCATED:
            # Any FDs that did arrive are already open in this process.
            # Close them so that rejecting the message doesn't leak them.
            self._close_ancdata_fds(ancdata)
            raise RuntimeError("incoming message truncated")
        return nbytes, ancdata

    def recv(self, make_unpickle_context=UnpickleContext, *, block=True):
        """Receive an object sent with send()

        BLOCK controls whether the receive blocks.

        Important! If BLOCK is false and reading a message would block, we
        raise BlockingIOError! You might think we could return a sentinel
        (like None) in this case, but if we did that, we couldn't
        distinguish receiving the sentinel from failing to read a message.

        If the other end of the stream is disconnected, raise
        PeerDiedException. If BLOCK is true, block until we get a message
        or until some other error occurs. Raise RuntimeError if the
        incoming message was truncated.

        MAKE_UNPICKLE_CONTEXT is passed through to fancy_unpickle.
        """
        data_buf = bytearray(MAX_MSG_SIZE)  # TODO(dancol): POOL!
        try:
            nbytes, ancdata = self._recvmsg(data_buf, block)
        except ConnectionError as ex:
            raise PeerDiedException from ex
        if not nbytes:
            # An empty read is treated as the peer having closed its end.
            # Close any stray FDs so we don't leak them.
            self._close_ancdata_fds(ancdata)
            raise PeerDiedException
        return fancy_unpickle(
            memoryview(data_buf)[:nbytes],
            _extract_fds_from_ancdata(ancdata) if ancdata else (),
            make_unpickle_context)

    def shutdown(self, direction):
        """Shut down the connection in one direction

        DIRECTION is SHUT_RD or SHUT_WR."""
        assert direction in (SHUT_RD, SHUT_WR)
        self._socket.shutdown(direction)

    def _do_close(self):
        self._socket.close()
        super()._do_close()

    def __repr__(self):
        return "<{} {!r}>".format(type(self), self._socket)
class ChannelDgram(Channel):
    """SOCK_SEQPACKET emulating channel based on SOCK_DGRAM

    Besides the SOCK_DGRAM data socket, each end holds one end of a
    SOCK_STREAM "canary" socketpair. Nothing in this module writes to
    the canary, so a canary recv() that returns instead of raising
    BlockingIOError means the peer has died. All of this channel's
    sockets are non-blocking, and blocking sends and receives are
    emulated with poll().
    """

    @classmethod
    def make_pair_1(cls):
        """Implementation of make_pair"""
        socket_a, socket_b = socketpair(AF_UNIX, SOCK_DGRAM)
        try:
            canary_a, canary_b = socketpair(AF_UNIX, SOCK_STREAM)
        except BaseException:
            # Don't leak the data sockets if the canary pair can't be made.
            socket_a.close()
            socket_b.close()
            raise
        return cls(socket_a, canary_a), cls(socket_b, canary_b)

    def __init__(self, sock, canary):
        super().__init__(sock)
        self.__canary = the(socket, canary)
        for fd_number in self.fd_numbers:
            os.set_blocking(fd_number, False)  # Can't rely on MSG_DONTWAIT

    @cached_property
    def fd_numbers(self):
        """Tuple of the data socket's and the canary's file descriptors"""
        return self._socket.fileno(), self.__canary.fileno()

    @cached_property
    def __recv_poller(self):
        # Wakes when either the data socket or the canary is readable.
        poll = select.poll()
        for thing in self._socket, self.__canary:
            poll.register(thing.fileno(), select.POLLIN)
        return poll

    def _recvmsg(self, data_buf, block):
        # Emulate blocking IO with a poll loop so we can watch for canary
        # disconnect too. The data socket is always checked first, so a
        # message sent just before the peer died is still delivered.
        while True:
            try:
                return super()._recvmsg(data_buf, False)
            except BlockingIOError:
                pass
            if not block:
                self.__canary.recv(1)  # Will raise BlockingIOError...
                raise PeerDiedException  # ...unless peer has died
            try:
                self.__canary.recv(1)
            except BlockingIOError:
                self.__recv_poller.poll()
            else:
                raise PeerDiedException

    def sendmsg_raw(self, pickle_data, ancdata, block):
        """Send one message, emulating blocking mode when BLOCK is true

        Our sockets are non-blocking, so a blocking send retries after
        waiting for the data socket to become writable. It returns as
        soon as one attempt succeeds.
        """
        if not block:
            super().sendmsg_raw(pickle_data, ancdata, block)
            return
        while True:  # Emulate blocking mode
            try:
                super().sendmsg_raw(pickle_data, ancdata, False)
            except BlockingIOError:
                self._send_poller.poll()
            else:
                return

    def shutdown(self, direction):
        super().shutdown(direction)
        self.__canary.shutdown(direction)

    def _do_close(self):
        self.__canary.close()
        super()._do_close()

    def __repr__(self):
        return "<{} {!r} {!r}>".format(type(self), self._socket, self.__canary)
class ChannelSeqPacket(Channel):
    """Channel backed directly by an AF_UNIX SOCK_SEQPACKET socket"""

    @classmethod
    def make_pair_1(cls):
        """Create two connected SOCK_SEQPACKET channels (make_pair backend)"""
        return tuple(map(cls, socketpair(AF_UNIX, SOCK_SEQPACKET)))