| # Copyright (C) 2020 The Android Open Source Project |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
#      http://www.apache.org/licenses/LICENSE-2.0
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Connection that supports our pickling scheme""" |
| |
| import os |
| import logging |
| import select |
| from errno import ESOCKTNOSUPPORT, ENOTCONN |
| from array import array |
| |
| import socket as socket_module |
| from socket import ( |
| AF_UNIX, |
| CMSG_LEN, |
| MSG_CTRUNC, |
| MSG_TRUNC, |
| SCM_RIGHTS, |
| SHUT_RD, |
| SHUT_WR, |
| SOCK_DGRAM, |
| SOCK_STREAM, |
| SOL_SOCKET, |
| socket, |
| socketpair, |
| ) |
| |
| from .util import ( |
| FdHolder, |
| SafeClosingObject, |
| cached_property, |
| die_due_to_fatal_exception, |
| the, |
| ) |
| |
| from .reduction import ( |
| MessageTooLargeError, |
| PickleContext, |
| UnpickleContext, |
| fancy_pickle, |
| fancy_unpickle, |
| ) |
| |
log = logging.getLogger(__name__)

# Optional socket constants. A platform that lacks one gets 0, which the
# rest of this module treats as "feature unavailable" (e.g. the
# `if not MSG_CMSG_CLOEXEC` and `SOCK_SEQPACKET and MSG_DONTWAIT` checks).
MSG_CMSG_CLOEXEC = getattr(socket_module, "MSG_CMSG_CLOEXEC", 0)
SOCK_SEQPACKET = getattr(socket_module, "SOCK_SEQPACKET", 0)
MSG_DONTWAIT = getattr(socket_module, "MSG_DONTWAIT", 0)
MSG_NOSIGNAL = getattr(socket_module, "MSG_NOSIGNAL", 0)

# Debug-only testing aid: MODERNMP_SIMULATE_MACOS=1 pretends the optional
# features above are missing so the fallback code paths get exercised.
if __debug__ and os.getenv("MODERNMP_SIMULATE_MACOS") == "1":
    MSG_NOSIGNAL = MSG_DONTWAIT = SOCK_SEQPACKET = 0
| |
# AF_UNIX SOCK_DGRAM notes:
#
# Say we have sockets A and B connected to each other. What happens
# when A tries to send and B is dead? We can get any of ECONNREFUSED
# (first send after B closes), ENOTCONN (subsequent sends from A), or
# EPIPE (B shuts down reading, but doesn't close). They all mean the
# same thing though: the peer is dead.
#
# We always use born-connected socketpairs, so we don't have to worry
# about endpoint binding.
#
| |
# Largest datagram payload, in bytes. Channel.send hands this to
# fancy_pickle as its size_limit and Channel.recv sizes its receive buffer
# with it. Messages larger than this size get sent indirectly via a file
# descriptor that we send over the datagram.
MAX_MSG_SIZE = 65536

# Upper bound on the number of file descriptors attached to one message.
MAX_MSG_FD = 1024

# Ancillary-data buffer size for receiving MAX_MSG_FD descriptors, at 4
# bytes per descriptor (the C int of the "i" array used for SCM_RIGHTS).
MSG_ANC_SIZE = CMSG_LEN(MAX_MSG_FD * 4)

# recvmsg() flags: request close-on-exec descriptors where supported;
# the non-blocking variant adds MSG_DONTWAIT.
RECVMSG_BLOCK_FLAGS = MSG_CMSG_CLOEXEC
RECVMSG_NONBLOCK_FLAGS = MSG_DONTWAIT | RECVMSG_BLOCK_FLAGS

# sendmsg() flags: suppress SIGPIPE where supported; the non-blocking
# variant adds MSG_DONTWAIT.
SENDMSG_BLOCK_FLAGS = MSG_NOSIGNAL
SENDMSG_NONBLOCK_FLAGS = MSG_DONTWAIT | SENDMSG_BLOCK_FLAGS

# recvmsg() result flags meaning the payload or ancillary data was cut off.
RECV_FLAGS_TRUNCATED = MSG_CTRUNC | MSG_TRUNC
| |
class PeerDiedException(Exception):
    """Raised when the endpoint on the other side of a Channel is gone.

    Channel folds the socket errors that indicate a dead peer (connection
    errors, ENOTCONN, a zero-length read) into this one exception type.
    """
| |
def _extract_fds_from_ancdata(ancdata):
    """Take ownership of the file descriptors carried in ANCDATA.

    ANCDATA is the ancillary-data list returned by recvmsg/recvmsg_into.
    Every SOL_SOCKET/SCM_RIGHTS item is decoded as an array of C ints;
    trailing bytes that do not form a whole int are ignored. Each raw fd
    is wrapped with FdHolder.steal. Returns the list of holders.
    """
    try:
        received = array("i")
        int_size = received.itemsize
        for level, kind, payload in ancdata:
            if level != SOL_SOCKET or kind != SCM_RIGHTS:
                continue
            usable = len(payload) - len(payload) % int_size
            received.frombytes(payload[:usable])
        fds = [FdHolder.steal(raw_fd) for raw_fd in received]
    except BaseException:
        # If we fail here, we leak the SCM_RIGHTS FDs with no way to
        # recover, so just die.
        die_due_to_fatal_exception("slurping ancdata")
    # Without MSG_CMSG_CLOEXEC the kernel hands us inheritable fds, so mark
    # them close-on-exec by hand.
    if not MSG_CMSG_CLOEXEC:
        for fd in fds:
            os.set_inheritable(fd.fileno(), False)
    return fds
| |
def fdhs_to_ancdata(fdhs):
    """Convert a sequence of FdHolder objects to an ancdata array

    Returns an empty tuple when FDHS is empty or falsy; otherwise a
    one-element list holding a SOL_SOCKET/SCM_RIGHTS item whose data is an
    "i" array of the holders' fd numbers, ready for socket.sendmsg.
    Raises MessageTooLargeError when more than MAX_MSG_FD holders are given.
    """
    if not fdhs:
        return ()
    count = len(fdhs)
    if count > MAX_MSG_FD:
        raise MessageTooLargeError(
            "too many FDs: want {!r}; max is {!r}".format(count, MAX_MSG_FD))
    fd_array = array("i", [holder.fileno() for holder in fdhs])
    return [(SOL_SOCKET, SCM_RIGHTS, fd_array)]
| |
class MessageSendContext(PickleContext):
    """PickleContext that also performs (commits) the actual message send"""

    def sendmsg(self,
                pickle_data: bytes,
                block: bool,
                channel: "Channel") -> None:
        """Send PICKLE_DATA, plus this context's FDs, over CHANNEL.

        Subclasses can override this method to perform actions before and
        after message sends. The implementation should call sendmsg_raw
        on CHANNEL or something with an equivalent effect. BLOCK is passed
        through to sendmsg_raw.
        """
        send_raw = channel.sendmsg_raw
        send_raw(pickle_data, fdhs_to_ancdata(self.fdhs), block)
| |
class Channel(SafeClosingObject):
    """General-purpose object-based intramachine socket connection

    This class does no internal locking. You must make sure that at most
    one reader and one writer can be using the Channel across all
    processes.

    Create instances with make_pair(), which picks a concrete subclass
    (ChannelSeqPacket or ChannelDgram) according to platform support.
    """

    @classmethod
    def __make_pair_1(cls):
        # candidates.pop() takes from the end, so SOCK_SEQPACKET (when
        # allowed and available) is tried before the SOCK_DGRAM emulation.
        candidates = [ChannelDgram.make_pair_1]
        if (os.getenv("MODERNMP_FORCE_DGRAM") != "1" and
                SOCK_SEQPACKET and
                MSG_DONTWAIT):
            candidates.append(ChannelSeqPacket.make_pair_1)
        while candidates:
            candidate = candidates.pop()
            try:
                channels = candidate()
            except OSError as ex:
                # Only "socket type not supported" means "try the next
                # implementation"; anything else is a real error.
                if ex.errno != ESOCKTNOSUPPORT:
                    raise
            else:
                # Remember the implementation that worked so later calls
                # skip the probing.
                cls.__make_pair_1 = candidate
                return channels
        raise RuntimeError("no way to create sockets")

    @classmethod
    def make_pair(cls, duplex: bool = True):
        """Make a pair of connected channels

        Return a pair (CHANNEL_A, CHANNEL_B).

        If DUPLEX is true, data can flow in both directions and the two
        channels are indistinguishable.  If DUPLEX is false, data can flow
        only from CHANNEL_B to CHANNEL_A: CHANNEL_A is shut down for
        writing and CHANNEL_B is shut down for reading.
        """
        channels = cls.__make_pair_1()
        if not duplex:
            channels[0].shutdown(SHUT_WR)
            channels[1].shutdown(SHUT_RD)
        return channels

    def __init__(self, sock):
        """Private constructor: use the make_pair() constructor"""
        assert isinstance(sock, socket)
        assert sock.family == AF_UNIX
        assert sock.type in (SOCK_SEQPACKET, SOCK_DGRAM)
        self._socket = sock

    def __getstate__(self):
        # Drop values memoized by cached_property (fd numbers, select.poll
        # objects) so they are not carried across a pickle.  Walk the whole
        # MRO rather than just the concrete class: most cached properties
        # (fd_numbers, _send_poller) live on Channel, but instances are
        # always of a subclass, whose own __dict__ does not contain them.
        cacheless_dict = self.__dict__.copy()
        for klass in type(self).__mro__:
            for name, value in vars(klass).items():
                if isinstance(value, cached_property):
                    cacheless_dict.pop(name, None)
        return cacheless_dict

    @cached_property
    def fd_numbers(self):
        """Tuple of all file descriptors that we care about"""
        return self._socket.fileno(),

    @property
    def data_socket(self):
        """The data socket, for commands"""
        return self._socket

    def sendmsg_raw(self,
                    pickle_data: bytes,
                    ancdata,
                    block: bool) -> None:
        """Send a message

        PICKLE_DATA is the serialized message.  ANCDATA is the ancillary
        data for socket.sendmsg: a sequence of (level, type, data) tuples
        such as the one built by fdhs_to_ancdata.  BLOCK selects the
        blocking or non-blocking send flags.
        """
        if __debug__:
            self.__check_blocking(block)
        flags = SENDMSG_BLOCK_FLAGS if block else SENDMSG_NONBLOCK_FLAGS
        ret = self._socket.sendmsg([pickle_data], ancdata, flags)
        # Datagram-style sockets send the whole message or nothing.
        assert ret == len(pickle_data)

    def poll_for_send(self):
        """Wait until the channel becomes ready for send"""
        return self._send_poller.poll()

    @cached_property
    def _send_poller(self):
        poll = select.poll()
        poll.register(self._socket, select.POLLOUT)
        return poll

    def send(self,
             obj,
             make_message_send_context=MessageSendContext,
             *,
             block=True):
        """Send an object through this connection

        OBJ is the object to send, which we'll send with fancy_pickle.
        If the pickling generates file descriptors, send them along with the
        message using SCM_RIGHTS.

        BLOCK controls whether the send blocks.

        MAKE_MESSAGE_SEND_CONTEXT is a function of no arguments that
        returns a MessageSendContext object, which also becomes the
        pickle context.

        If the target ceases to exist while we're sending the message, raise
        PeerDiedException. (We fold socket errors and resource manager
        errors into this one exception type.) All other exceptions are
        serious internal errors.

        Return true if we sent a message and false if we successfully
        failed to send a message. The latter case occurs only when BLOCK
        is false and sending the message would have blocked.
        """
        pickle_data, pc = fancy_pickle(
            obj,
            size_limit=MAX_MSG_SIZE,
            make_pickle_context=make_message_send_context)
        assert isinstance(pc, MessageSendContext)
        assert len(pickle_data) <= MAX_MSG_SIZE
        try:
            pc.sendmsg(pickle_data, block, self)
        except BlockingIOError:
            assert not block
            return False
        except ConnectionError as ex:
            raise PeerDiedException from ex
        except OSError as ex:
            # ENOTCONN is not mapped to a ConnectionError subclass, but for
            # our born-connected sockets it also means the peer is gone.
            if ex.errno == ENOTCONN:
                raise PeerDiedException from ex
            raise
        return True

    def __check_blocking(self, block):
        # Debug-only sanity check: a blocking call needs a blocking fd; a
        # non-blocking call needs MSG_DONTWAIT or a non-blocking fd.
        if __debug__:
            fd_is_blocking = os.get_blocking(self._socket.fileno())
            if block:
                assert fd_is_blocking
            else:
                assert MSG_DONTWAIT or not fd_is_blocking

    def _recvmsg(self, data_buf, block):
        """Receive one message into DATA_BUF; return (NBYTES, ANCDATA).

        Raise RuntimeError if the payload or ancillary data was truncated.
        """
        if __debug__:
            self.__check_blocking(block)
        flags = RECVMSG_BLOCK_FLAGS if block else RECVMSG_NONBLOCK_FLAGS
        nbytes, ancdata, recv_flags, _address = \
            self._socket.recvmsg_into([data_buf], MSG_ANC_SIZE, flags)
        if recv_flags & RECV_FLAGS_TRUNCATED:
            # NOTE(review): SCM_RIGHTS fds that did arrive in ANCDATA are
            # not closed on this path -- confirm whether to reclaim them.
            raise RuntimeError("incoming message truncated")
        return nbytes, ancdata

    def recv(self, make_unpickle_context=UnpickleContext, *, block=True):
        """Receive an object sent with send()

        BLOCK controls whether the receive blocks.

        Important! If BLOCK is false and reading a message would block, we
        raise BlockingIOError! You might think we could return a sentinel
        (like None) in this case, but if we did that, we couldn't
        distinguish receiving the sentinel from failing to read a message.

        If the other end of the stream is disconnected, raise
        PeerDiedException. If BLOCK is true, block until we get a message
        or until some other error occurs.

        MAKE_UNPICKLE_CONTEXT is passed through to fancy_unpickle along
        with the payload and any received file descriptors.
        """
        data_buf = bytearray(MAX_MSG_SIZE)  # TODO(dancol): POOL!
        try:
            nbytes, ancdata = self._recvmsg(data_buf, block)
        except ConnectionError as ex:
            raise PeerDiedException from ex
        if not nbytes:
            # A zero-length read is treated as the peer hanging up.
            raise PeerDiedException
        return fancy_unpickle(
            memoryview(data_buf)[:nbytes],
            _extract_fds_from_ancdata(ancdata) if ancdata else (),
            make_unpickle_context)

    def shutdown(self, direction):
        """Shut down the connection in one direction

        DIRECTION is SHUT_RD or SHUT_WR."""
        assert direction in (SHUT_RD, SHUT_WR)
        self._socket.shutdown(direction)

    def _do_close(self):
        self._socket.close()
        super()._do_close()

    def __repr__(self):
        return "<{} {!r}>".format(type(self).__name__, self._socket)
| |
class ChannelDgram(Channel):
    """SOCK_SEQPACKET-emulating channel based on SOCK_DGRAM

    Each endpoint owns a SOCK_DGRAM socket for messages plus one end of a
    SOCK_STREAM "canary" socket pair.  Nothing in this class ever writes to
    the canary, so a successful read from it (EOF) is taken to mean the
    peer is gone; _recvmsg uses this to detect peer death.  All fds are
    non-blocking, and blocking operations are emulated with poll().
    """

    @classmethod
    def make_pair_1(cls):
        """Implementation of make_pair"""

        socket_a, socket_b = socketpair(AF_UNIX, SOCK_DGRAM)
        try:
            canary_a, canary_b = socketpair(AF_UNIX, SOCK_STREAM)
        except BaseException:
            # Don't leak the datagram pair if the canary pair fails.
            socket_a.close()
            socket_b.close()
            raise
        return cls(socket_a, canary_a), cls(socket_b, canary_b)

    def __init__(self, sock, canary):
        super().__init__(sock)
        self.__canary = the(socket, canary)
        for fd_number in self.fd_numbers:
            os.set_blocking(fd_number, False)  # Can't rely on MSG_DONTWAIT

    @cached_property
    def fd_numbers(self):
        numbers = self._socket.fileno(), self.__canary.fileno()
        return numbers

    @cached_property
    def __recv_poller(self):
        poll = select.poll()
        for thing in self._socket, self.__canary:
            poll.register(thing.fileno(), select.POLLIN)
        return poll

    def _recvmsg(self, data_buf, block):
        # Emulate blocking IO with a poll loop so we can watch for canary
        # disconnect too.  The data socket is always tried first, so queued
        # messages are returned before peer death is reported.
        while True:
            try:
                return super()._recvmsg(data_buf, False)
            except BlockingIOError:
                pass
            if not block:
                self.__canary.recv(1)  # Will raise BlockingIOError...
                raise PeerDiedException  # ...unless peer has died
            try:
                self.__canary.recv(1)
            except BlockingIOError:
                # Peer alive but nothing to read: wait on both sockets.
                self.__recv_poller.poll()
            else:
                raise PeerDiedException

    def sendmsg_raw(self, pickle_data, ancdata, block):
        if not block:
            super().sendmsg_raw(pickle_data, ancdata, block)
            return
        # Emulate blocking mode on our non-blocking fd: retry after waiting
        # for POLLOUT, and stop as soon as one send succeeds (the message
        # and its FDs must go out exactly once).
        while True:
            try:
                super().sendmsg_raw(pickle_data, ancdata, False)
            except BlockingIOError:
                self._send_poller.poll()
            else:
                return

    def shutdown(self, direction):
        super().shutdown(direction)
        self.__canary.shutdown(direction)

    def _do_close(self):
        self.__canary.close()
        super()._do_close()

    def __repr__(self):
        return "<{} {!r} {!r}>".format(
            type(self).__name__, self._socket, self.__canary)
| |
class ChannelSeqPacket(Channel):
    """Channel backed directly by an AF_UNIX SOCK_SEQPACKET socket pair"""

    @classmethod
    def make_pair_1(cls):
        """Implement make_pair for seqpacket connections"""
        first, second = socketpair(AF_UNIX, SOCK_SEQPACKET)
        return (cls(first), cls(second))