| # Copyright (C) 2020 The Android Open Source Project |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
# http://www.apache.org/licenses/LICENSE-2.0
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """A slower, shittier, and buggier version of Binder""" |
| |
| # pylint: disable=bare-except |
| |
| import itertools |
| import logging |
| import os |
| import select |
| import threading |
| import weakref |
| |
| from contextlib import contextmanager |
| from functools import wraps, partial |
| from multiprocessing.util import _run_after_forkers |
| from socket import SHUT_RD |
| from fcntl import ioctl |
| from termios import TIOCSCTTY |
| |
| from .reduction import ( |
| PickleContext, |
| UnpickleContext, |
| fancy_pickle, |
| fancy_unpickle, |
| ) |
| |
| from .channel import ( |
| Channel, |
| MessageSendContext, |
| PeerDiedException, |
| fdhs_to_ancdata, |
| ) |
| |
| from .waitsem import WaitableSemaphore |
| |
| from .util import ( |
| CannotPickle, |
| ChildExitWatcher, |
| ClosingContextManager, |
| FdHolder, |
| MultiDict, |
| NoneType, |
| ReferenceCountTracker, |
| SafeClosingObject, |
| cached_property, |
| capture_exc_info, |
| die_due_to_fatal_exception, |
| is_lock_owned_by_current_thread, |
| once, |
| rename_code_object, |
| reraise_exc_info, |
| the, |
| tls, |
| unix_pipe, |
| ) |
| |
| logc = logging.getLogger(__name__ + ".client") # pylint: disable=invalid-name |
| logc_tx = logging.getLogger(__name__ + ".client.tx") # pylint: disable=invalid-name |
| logr = logging.getLogger(__name__ + ".resman") # pylint: disable=invalid-name |
| |
| # Maximum number of notifications we batch into a single notification |
| # packet. Should be small enough that the resulting pickled message |
| # still fits into a MAX_MSG_SIZE-byte datagram. |
| MAX_NOTIFICATION_BATCH = 32 |
| |
| # Maximum number of notification packets we allow to sit in a resmanc |
| # socket buffer before we require explicit acknowledgment. |
| # TODO(dancol): use a semaphore instead of explicit windowing? |
| MAX_NOTIFICATION_UNACKED = 2 |
| |
| MSG_REPLY_SUCCESS = 0 |
| MSG_REPLY_ERROR = 1 |
| MSG_COMMAND = 2 |
| MSG_ACK_NOTIFICATION = 3 |
| MSG_NOTIFICATION_MIN = 4 |
| MSG_NOTIFICATION_OID_GONE = MSG_NOTIFICATION_MIN + 0 |
| MSG_NOTIFICATION_QUIT_THREAD = MSG_NOTIFICATION_MIN + 1 |
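
# Wire-format sketch, as used by ClientConnection._transact and
# ServerConnection._invoke_api below:
#
#   request:      (MSG_COMMAND, api_nr, args, kwargs)
#   reply:        (MSG_REPLY_SUCCESS, payload) or (MSG_REPLY_ERROR, exc_info)
#   notification: (MSG_NOTIFICATION_*, [oid, ...]), acknowledged by the
#                 client with (MSG_ACK_NOTIFICATION, nr_acked)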
| |
OID_INVALID = 0 # So bool evaluation of OIDs doesn't get confused
| OID_SELF_CONNECTION = 1 # Resource manager's connection to itself |
| OID_INITIAL_CONNECTION_ID = 2 # Initial resource manager client |
| |
# Useful hack to run everything in a single process for better
# logging, debugging, and test startup performance. Normally we
# prefer the more robust multi-process approach.
| USE_RESMANC_THREAD = os.getenv("MODERNMP_USE_RESMANC_THREAD") == "1" |
| |
| def _configure_resman_logging(): |
| shm_verbose = int(os.getenv("MODERNMP_SHM_VERBOSE", "0")) |
| logr.setLevel(logging.INFO) |
| logc.setLevel(logging.INFO) |
| logc_tx.setLevel(logging.INFO) |
| if shm_verbose > 0: # pylint: disable=compare-to-zero |
| logr.setLevel(logging.DEBUG) |
| logc.setLevel(logging.DEBUG) |
| if shm_verbose > 1: |
| logc_tx.setLevel(logging.DEBUG) |
| _configure_resman_logging() |
| |
| _resman_api_functions = [] # pylint: disable=invalid-name |
| _resman_api_counts = None |
| |
| class ResourceManagerError(Exception): |
| """Base exception for expected resource manager errors |
| |
| Any exception derived from this one is propagated from the resource |
| manager to clients; other exceptions instantly kill both the |
| resource manager process and (indirectly, through resource manager |
| death) all clients. |
| """ |
| |
| class ConnectionNotFoundError(ResourceManagerError): |
| """Exception thrown when we can't locate a connection""" |
| |
| class ResourceConstructionFailedError(ResourceManagerError): |
| """Exception thrown when building a resource fails. |
| |
| The original exception is chained.""" |
| |
| class SendMsgError(ResourceManagerError): |
| """Exception thrown when sendmsg goes wrong""" |
| |
| class _DoForkChildException(BaseException): |
| """Special exception to break out of event loop in fork child""" |
| def __init__(self, info): |
| super().__init__("fork child") |
| self.info = info |
| |
| def _assert_valid_oid(oid): |
| assert the(int, oid) > 0 # pylint: disable=compare-to-zero |
| return True |
| |
| @once() |
| def get_global_process_resmanc(): |
| """Get the process-global resource manager""" |
| name = "resman/{}".format(os.getpid()) |
  return (start_resource_manager_thread
          if USE_RESMANC_THREAD
          else start_resource_manager_process)(name)
| |
| @once() |
| def get_temp_name_sequence(): |
| """Get an infinite generator of random names""" |
| # pylint: disable=protected-access |
| from tempfile import _RandomNameSequence |
| return _RandomNameSequence() |
| |
| if USE_RESMANC_THREAD: |
  # We need a TLS lookup so that the resmanc thread can use a
  # different client instance when talking to itself.
| def get_process_resmanc(): |
| """Get the main process connection to the resource manager |
| |
| Despite the name, may be thread-local. A thread-local resource |
| manager is useful primarily when running the resource manager as a |
| separate thread in the current process, which we do when we run the |
| resource manager tests. |
| """ |
| return tls.resmanc or get_global_process_resmanc() |
| else: |
| # Not using threaded mode, so can always just use the global, even |
| # in the resman server. |
| get_process_resmanc = get_global_process_resmanc |
| |
| def get_oid(obj): |
| """Return the OID of OBJ if managed, else None.""" |
| # TODO(dancol): get rid of this function |
| assert isinstance(obj, SharedObject) |
| return obj.oid |
| |
| def _resman_api(*, oneway=False): |
| def _decorator(fn): |
| fn.oneway = oneway |
| _resman_api_functions.append(fn) |
| return fn |
| return _decorator |
| |
| def _do_cache_newborn(obj_type, obj_dict): |
| return get_process_resmanc().internal_cache_newborn( |
| obj_type, obj_dict) |
| |
| class _RegisterNewbornThunk(object): |
| """Protect a newborn object from reference pickling. |
| |
| In functions like ServerConnection.adopt_resource, we return a |
| resource value to the client. The return value of this function gets |
| pickled along with the rest of the response payload in |
| ServerConnection.__send. Ordinarily, when we pickle an object we |
| know about, we encode it as a reference. The client, in the process |
| of unpickling the reference, talks to the resource manager and gets |
| the real object payload. If we went through this pickling process |
| for the return value of ServerConnection.adopt_resource, we'd |
| recurse forever. |
| |
| Instead, we return an instance of this object from ServerConnection |
| functions that return resources "by value". A _RegisterNewbornThunk |
| pickles to a special callback function that causes the recipient to |
| register the object instance in its own process resource manager |
| client. In addition, we protect the referenced object from being |
| pickled as a pointer, preventing infinite recursion during |
| unpickling. |
| """ |
| def __init__(self, obj): |
| assert isinstance(obj, SharedObject) |
| assert obj.oid |
| self.obj = obj |
| |
| def __reduce__(self): |
| obj = self.obj |
| return _do_cache_newborn, (type(obj), obj.__dict__) |
| __slots__ = ["obj"] |
| |
| class ServerConnection(CannotPickle): |
| """A ResourceManagerServer-side connection to a client""" |
| |
| __dead = False |
| """Whether this connnection has failed and is being cleaned up""" |
| |
| __nr_notifications_unacked = 0 |
| """Number of notifications the peer hasn't acknowledged""" |
| |
| __critical = False |
| """Kill the whole server if this connection dies""" |
| |
| def __init__(self, channel, server, name): |
| self.__channel = the((Channel, NoneType), channel) |
| self.__name = the(str, name) |
| self.__server = the(ResourceManagerServer, server) |
| self.__id = the(int, server.allocate_oid()) |
| self.__references = ReferenceCountTracker() |
| self.__subscribed = set() |
| self.__notifications = [] |
| self.fd_numbers = channel.fd_numbers if channel else () |
| |
| @property |
| def name(self): |
| """Name of this server connection""" |
| return self.__name |
| |
| @cached_property |
| def id(self): |
| """The ID of this connection""" |
| return self.__id |
| |
| @property |
| def dead(self): |
| """Whether this connection is marked for removal""" |
| return self.__dead |
| |
| def __lookup_conn(self, connection_id): |
| """Find a connection by id. |
| |
| CONNECTION_ID should be a numeric connection id. If |
| it is None, return self instead.""" |
| if connection_id: |
| return self.__server.find_connection_by_id(connection_id) |
| return self |
| |
| def fork(self, fn, process_resmanc_factory): |
| """Fork the resman process: see LocalClient.fork()""" |
| return self.__server.fork(fn, process_resmanc_factory) |
| |
| @_resman_api() |
| def make_connection_factory(self, name): |
| """Make a new resource manager connection factory. |
| |
| We return a factory instead of a connection because the factory |
| can be sent between processes, while a full connection cannot |
| be. The returned factory is good for building one |
| ServerConnection; it is useless after being called. |
| """ |
| assert isinstance(name, str) |
| # On error, the client socket disappears and we clean up elegantly |
| client_channel, server_channel = Channel.make_pair() |
| server_connection = self.__server.add_connection(server_channel, name) |
| client_connection_factory = ClientConnectionFactory( |
| client_channel, |
| server_connection.id, |
| name) |
| if __debug__: |
| logr.debug("new connection factory %r", client_connection_factory) |
| return client_connection_factory |
| |
| @_resman_api() |
| def make_resource(self, resource_type, args, kwargs): |
| """Allocate a resource in the resman process. |
| |
| RESOURCE_TYPE is a subclass of SharedObject. ARGS and KWARGS are |
| given to its __init__ to create a new object. If constructing the |
    resource object raises an exception, raise a
| ResourceConstructionFailedError with the original exception as its |
| __cause__. |
| |
| The pickle facility of the resource object will be used only when |
| transferring a resource instance from the resource manager to a |
| caller; when processes send resource objects to each other, they |
| send resource pointers. These resource pointers magically retrieve |
| their referenced objects from the resource manager |
| during unpickling. |
| |
| Return an instance of the resource's client object. The calling |
| connection owns a single reference to the resource. |
| """ |
| assert issubclass(resource_type, SharedObject) |
| try: |
| resource = object.__new__(resource_type) |
| resource_type.__init__(resource, *args, **kwargs) |
| except Exception as ex: |
| raise ResourceConstructionFailedError() from ex |
| oid = self.__server.add_resource(resource) |
| if __debug__: |
| logr.debug("made resource of type %r oid=%r", resource_type, oid) |
| assert oid == resource.oid |
| assert self.__server.get_refcount(oid) |
| old_rc = self.__references.addref(oid) |
| assert not old_rc |
| return _RegisterNewbornThunk(resource) |
| |
| def add_local_refcount(self, oid): |
| """Add a reference to resource OID""" |
| assert _assert_valid_oid(oid) |
| if __debug__: |
| logr.debug("add_local_refcount owner=%r oid=%r newrc=%r", |
| self.id, oid, self.__references.get(oid, 0) + 1) |
| if not self.__references.addref(oid): |
| self.__server.addref_unchecked(oid) |
| |
| def add_local_refcounts(self, oids): |
| """Batch add reference counts""" |
| for oid in oids: |
| self.add_local_refcount(oid) |
| |
| def release_local_refcounts(self, oids): |
| """Batch release""" |
| for oid in oids: |
| self.release_resource(oid) |
| |
| @_resman_api(oneway=True) |
| def addref_for_test(self, oid, target_oid=None): |
| """Test function: add an artificial reference count to OID""" |
| assert self.has_reference(oid) |
| self.__lookup_conn(target_oid).add_local_refcount(oid) |
| |
| @_resman_api() |
| def adopt_resource(self, oid): |
| """Retrieve an instance of an existing resource. |
| |
| Return (from client's perspective) a copy of the resource's held |
| object, wired up to call _do_cache_newborn. The caller's |
| reference to OID, which must exist, is reused. The resource's |
| reference count is unchanged. |
| """ |
| if __debug__: |
| logr.debug("adopting oid=%r", oid) |
| assert self.has_reference(oid) |
| return _RegisterNewbornThunk(self.__server.get_resource_unchecked(oid)) |
| |
| @_resman_api(oneway=True) |
| def link_to_death_internal(self, oid): |
| """Receive a notification when resource OID disappears. |
| |
| The caller does not need to have a reference to OID at the time of |
| call. If OID isn't currently alive, send a death |
| notification immediately. |
| |
| Idempotent in the OID-is-alive case: no matter how many times |
| link_to_death is called for a given OID, only one death |
| notification is received. |
| """ |
| assert self.__channel, "LocalClient cannot be notified" |
| if oid not in self.__subscribed: |
| if self.__server.add_subscriber(oid, self): |
| self.__subscribed.add(oid) |
| else: |
| self.notify(MSG_NOTIFICATION_OID_GONE, oid) |
| |
| @_resman_api(oneway=True) |
| def unlink_from_death_internal(self, oid): |
| """Undo the effect of link_to_death_internal. |
| |
| If no death link exists, do nothing. |
| """ |
| if oid in self.__subscribed: |
| self.__subscribed.remove(oid) |
| self.__server.remove_subscriber_unchecked(oid, self) |
| |
| @_resman_api(oneway=True) |
| def release_resource(self, oid, decrement_by=1): |
| """Drop references to OID. |
| |
| The calling connection must have a reference to the resource. |
| """ |
| if __debug__: |
| logr.debug("removing %r references from OID %r", decrement_by, oid) |
| if not self.__references.release(oid, decrement_by): |
| self.__server.release_resource_unchecked(oid) |
| |
| def has_reference(self, oid): |
| """Return whether this connection has a reference to OID""" |
| return oid in self.__references |
| |
| @_resman_api() |
| def get_refcounts(self, oid, resmanc=None): |
| """Return reference counts for OID. |
| |
| Return a tuple (LOCAL_REFCOUNT, GLOBAL_REFCOUNT). |
| |
| USE FOR TESTING ONLY. |
| |
| Connections manage their reference counts internally. All the |
| references a particular connection owns for a resource are |
| represented by a single increment of the global reference count |
| for that resource. Consequently, a resource's global reference |
| count is the number of connections that refer to it. |
| """ |
| # pylint: disable=protected-access |
| assert _assert_valid_oid(oid) |
| conn = self.__lookup_conn(resmanc) |
| return (conn.__references.get(oid, 0), |
| conn.__server.get_refcount(oid)) |
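
  # Worked example: if connections A and B hold 3 and 1 local references
  # to OID 7, A.get_refcounts(7) returns (3, 2): the global count is the
  # number of referring connections, not the sum of their local counts.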
| |
| @_resman_api() |
| def getpid(self): # pylint: disable=no-self-use |
| """Return the PID of the resource manager process.""" |
| return os.getpid() |
| |
| @_resman_api() |
| def get_exit_fd(self): |
| """Return a file descriptor that we can use to wait for server exit |
| |
| Works whether the server is a process or a thread. |
| """ |
| return self.__server.get_exit_fd() |
| |
| @_resman_api() |
| def snapshot_api_counts(self): |
| """Start tracking API call counts for test purposes""" |
| return self.__server.snapshot_api_counts() |
| |
| @_resman_api(oneway=True) |
| def explicit_resource_close(self, oid): |
| """Run early resource cleanup for resource OID |
| |
| The calling connection must have a reference to the resource. |
| """ |
| assert oid in self.__references |
| self.__server.explicit_resource_close(oid) |
| |
| @_resman_api() |
| def make_shared_memory_fd(self): # pylint: disable=no-self-use |
| """Make a shared memory segment and return an FdHolder for it""" |
| # We can make and unlink immediately here, so there's no need for |
| # resman-side cleanup via SharedObject. |
| from posix_ipc import SharedMemory, ExistentialError |
| shm = None |
| while not shm: |
| name = "/mmp.shm." + next(get_temp_name_sequence()) |
| try: |
| shm = SharedMemory(name, os.O_CREAT | os.O_EXCL) |
| except ExistentialError: |
| pass |
| fdh = FdHolder.steal(shm.fd) |
| assert not os.get_inheritable(shm.fd) |
| shm.unlink() |
| return fdh |
| |
| @_resman_api() |
| def cg_sendmsg_nonblock(self, cg_oid, data, fdhs, resources): |
| """Atomic message post and refcount manipulation |
| |
| CG_OID is the OID of a CallGate object, defined in apartment.py. |
| The caller must have a reference to this object. |
| |
| On successful send, return True. If the send would block, return |
| False. On socket errors, raise SendMsgError with the real error |
| chained in __cause__. If the CallGate object's target has ceased |
| to exist, raise a SendMsgError chained to PeerDiedException. |
| """ |
| assert self.has_reference(cg_oid) |
| assert all(self.has_reference(oid) for oid in resources) |
| cg = self.__server.get_resource_unchecked(cg_oid) |
| try: |
| dst = self.__server.find_connection_by_id_unchecked(cg.process_oid) |
| except KeyError: |
| raise SendMsgError from PeerDiedException() |
| dst.add_local_refcounts(resources) |
| try: |
| cg.channel.sendmsg_raw(data, fdhs_to_ancdata(fdhs), False) |
| except BlockingIOError: |
| dst.release_local_refcounts(resources) |
| return False |
| except OSError as ex: |
| dst.release_local_refcounts(resources) |
| raise SendMsgError from ex |
| return True |
| |
| @_resman_api() |
| def get_critical(self): |
| """Get the critical flag for this connection""" |
| return self.__critical |
| |
| @_resman_api(oneway=True) |
| def set_critical(self, is_critical: bool): |
| """Set whether this connection brings down the whole service""" |
| self.__critical = is_critical |
| |
| @_resman_api() |
| def get_process_status(self, oid): |
| """Get the exit status of process OID. |
| |
| OID must refer to a process object. |
| |
| This function is useful mostly for testing that process reaping is |
| actually working. |
| """ |
| assert self.has_reference(oid) |
| # pylint: disable=protected-access |
| return self.__server \ |
| .get_resource_unchecked(oid) \ |
| ._resman_get_process_status() |
| |
| def force_close(self): |
| """Force close a connection""" |
| if not self.dead and self.__channel: |
| self.__channel.shutdown(SHUT_RD) |
| |
| def force_drop_references(self): |
| """Forcibly dereference everything this connection owns. |
| |
| Used during connection teardown. |
| """ |
| assert self.dead |
| for subscribed_oid in list(self.__subscribed): |
| self.unlink_from_death_internal(subscribed_oid) |
| assert not self.__subscribed # Mutated above |
| for oid in tuple(self.__references): |
| del self.__references[oid] |
| self.__server.release_resource_unchecked(oid) |
| |
| def final_cleanup_on_death(self): |
| """Last method called by server before forgetting object""" |
| if self.__channel: |
| self.__channel.close() |
| if self.__critical: |
| self.__server.force_close_all_connections() |
| |
| def mark_dead(self): |
| """Start connection cleanup.""" |
| if not self.__dead: |
| self.__dead = True |
| self.__notifications.clear() |
| self.__server.on_connection_dead(self) |
| |
| def __recv(self, *, block=True): |
| try: |
| return self.__channel.recv( |
| ResmanClientRecvContext.from_(self), |
| block=block) |
| except BlockingIOError: |
| return None |
| except PeerDiedException: |
| self.mark_dead() |
| return None |
| |
| def __send(self, obj, block=True): |
| try: |
| return self.__channel.send(obj, |
| FromResmanSendContext.to(self), |
| block=block) |
| except PeerDiedException: |
| self.mark_dead() |
| return False |
| |
| def notify(self, code, oid): |
| """Enqueue a notification to be delivered""" |
| assert self.__channel |
| if not self.__dead: |
| if (self.__notifications and |
| self.__notifications[-1][0] == code and |
| len(self.__notifications[-1][1]) < MAX_NOTIFICATION_BATCH): |
| self.__notifications[-1][1].append(oid) |
| else: |
| self.__notifications.append((code, [oid])) |
| self.send_pending_notifications() |
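
  # Batching sketch: with MAX_NOTIFICATION_BATCH = 32, forty back-to-back
  # notify(MSG_NOTIFICATION_OID_GONE, oid) calls queue as two packets,
  # [(MSG_NOTIFICATION_OID_GONE, [oid_1 .. oid_32]),
  #  (MSG_NOTIFICATION_OID_GONE, [oid_33 .. oid_40])],
  # and at most MAX_NOTIFICATION_UNACKED of them sit unacknowledged in
  # the peer's socket buffer at once.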
| |
| def send_pending_notifications(self): |
| """Send some notifications, if any, to socket""" |
| while (self.__notifications and |
| self.__nr_notifications_unacked < MAX_NOTIFICATION_UNACKED and |
| self.__send(self.__notifications[0], block=False)): |
| if __debug__: |
| logr.debug("sent notification to socket") |
| self.__nr_notifications_unacked += 1 |
| self.__notifications.pop(0) |
| |
| def _invoke_api(self, msg): |
| _, api_nr, args, kwargs = msg |
| fn = _resman_api_functions[api_nr] |
| if __debug__: |
| logr.debug("API call from client %s to %s", |
| self.id, fn.__name__) |
| if fn.oneway: |
| fn(self, *args, **kwargs) |
| return None |
| try: |
| return (MSG_REPLY_SUCCESS, fn(self, *args, **kwargs)) |
| except ResourceManagerError: |
| return (MSG_REPLY_ERROR, capture_exc_info()) |
| |
| def handle_transaction(self): |
| """Receive message from client and reply. |
| |
| Return whether this connection should stay alive. |
| """ |
| while True: |
| msg = self.__recv(block=False) |
| if msg is None: |
| break |
| elif msg[0] == MSG_COMMAND: |
| reply = self._invoke_api(msg) |
| if reply: |
| self.__send(reply) |
| elif msg[0] == MSG_ACK_NOTIFICATION: |
| _, nr_acked = msg |
| assert isinstance(nr_acked, int) |
| assert nr_acked <= self.__nr_notifications_unacked |
| self.__nr_notifications_unacked -= nr_acked |
| else: |
| raise AssertionError("unknown message {!r}".format(msg)) |
| |
| def __repr__(self): |
| return "ServerConnection(name={!r}, id={!r}, channel={!r})".format( |
| self.name, self.id, self.__channel) |
| |
| def _rebuild_resource(oid): |
| assert isinstance(tls.unpickle_context, ResourceRecvContext) |
| return tls.unpickle_context.rebuild_resource(oid) |
| |
| class ShmSharedMemorySendContextMixin(MessageSendContext): |
| """Message sending mixin that makes shared memory with resman""" |
| |
| @staticmethod |
| def make_shared_memory_fd() -> FdHolder: |
| """Make a memory file descriptor""" |
| # TODO(dancol): use Linux memfd where available |
| return get_process_resmanc().make_shared_memory_fd() |
| |
| class AbstractResourceReducer(PickleContext): |
| """Interface for reducing resources in pickling""" |
| def reduce_resource(self, obj): |
| """__reduce__ for a SharedObject""" |
| raise NotImplementedError("abstract") |
| |
| class ResourceSendingMessageContextMixin(AbstractResourceReducer): |
| """Message sending mixin that accumulates a list of resources to send""" |
| |
| def __init__(self): |
| super().__init__() |
| self.resources = [] |
| |
| def reduce_resource(self, obj): |
| """__reduce__ for a SharedObject""" |
| # pylint: disable=protected-access |
| self.resources.append(obj.oid) |
| if obj._resman_send_eager: |
| return _RegisterNewbornThunk(obj).__reduce__() |
| return _rebuild_resource, (obj.oid,) |
| |
| class ToResmanSendContext( |
| ShmSharedMemorySendContextMixin, |
| MessageSendContext, |
| AbstractResourceReducer, |
| ): |
| """Message sending context for sending messages to resman itself""" |
| def reduce_resource(self, obj): |
| """__reduce__ for a SharedObject""" |
| # No need to do anything here: we don't have to tell the server to |
| # add a reference when it has all possible references already. |
| return _rebuild_resource, (obj.oid,) |
| |
| class FromResmanSendContext( |
| ShmSharedMemorySendContextMixin, |
| MessageSendContext, |
| ResourceSendingMessageContextMixin, |
| ): |
| """Message sending context that adds references to a connection |
| |
| Usable only from the resman server itself. |
| """ |
| |
| def __init__(self, to): |
| super().__init__() |
| self.__to = the(ServerConnection, to) |
| |
| @classmethod |
| def to(cls, to): |
| """Make a constructing function""" |
| return partial(cls, to) |
| |
| def sendmsg(self, pickle_data, block, channel): |
| self.__to.add_local_refcounts(self.resources) |
| try: |
| super().sendmsg(pickle_data, block, channel) |
| except: |
| self.__to.release_local_refcounts(self.resources) |
| raise |
| |
| class ResourceRecvContext(UnpickleContext): |
| """Receive context for receiving messages""" |
| @staticmethod |
| def rebuild_resource(oid): |
| """Translate a pickled OID into an object""" |
| return get_process_resmanc().load_resource(oid) |
| |
| class ResmanClientRecvContext(ResourceRecvContext): |
| """Receive resources from a resman client connection""" |
| def __init__(self, *, from_, **kwargs): |
| super().__init__(**kwargs) |
| self.__from = the(ServerConnection, from_) |
| |
| @classmethod |
| def from_(cls, from_): |
| """Make a constructing function""" |
| return partial(cls, from_=from_) |
| |
| def rebuild_resource(self, oid): |
| # We're the resource server, so there's no point in the client |
| # sending a separate message to addref ourselves. Just add |
| # virtual references on demand. |
| assert self.__from.has_reference(oid) |
| get_process_resmanc().server_connection.add_local_refcount(oid) |
| return super().rebuild_resource(oid) |
| |
| class ResourceManagerServer(object): |
| """Provides resource manager server side""" |
| |
| def __init__(self): |
| self.__connections_by_fileno = {} |
| self.__connections_by_id = {} |
| self.__resources = {} |
| self.__refcounts = {} |
| self.__subscribers = MultiDict() |
| self.__dead_connections = [] |
| self.__poll = select.poll() # TODO(dancol): use selectors |
| self.__oid_gen = itertools.count(1) |
| self.__randname_sequence = None |
| self.__exit_fd = None |
| |
| def allocate_oid(self): |
| """Make a new unique ID for connections, resources, etc.""" |
| return next(self.__oid_gen) |
| |
| def __all_connections(self): |
| return self.__connections_by_id.values() |
| |
| def add_connection(self, server_channel, name) -> ServerConnection: |
| """Add a new connection. |
| |
| SERVER_CHANNEL is the server side of the channel pair created |
| using Channel.make_pair(). |
| """ |
| connection = ServerConnection(server_channel, self, name) |
| if server_channel: |
| for fd_number in connection.fd_numbers: |
| assert fd_number not in self.__connections_by_fileno |
| self.__connections_by_fileno[fd_number] = connection |
| self.__poll.register(fd_number, select.POLLIN) |
| self.__connections_by_id[connection.id] = connection |
| if __debug__: |
| logr.debug("added connection fd_numbers=%r id=%r", |
| connection.fd_numbers, connection.id) |
| return connection |
| |
| def add_resource(self, resource): |
| """Add a new resource RESOURCE to the server. |
| |
| RESOURCE can be of any type derived from SharedObject. |
| |
| Return the new resource ID. |
| """ |
| assert isinstance(resource, SharedObject) |
| assert not resource.oid |
| resource.oid = oid = self.allocate_oid() |
| assert oid not in self.__resources |
| self.__refcounts[oid] = 1 |
| self.__resources[oid] = resource |
| return oid |
| |
| def addref_unchecked(self, oid): |
| """Add a reference to resource OID""" |
| assert oid in self.__resources |
| logr.debug("adding global refcount oid=%r newrc=%r", oid, |
| self.__refcounts.get(oid, 0) + 1) |
| self.__refcounts[oid] += 1 |
| |
| def has_resource(self, oid): |
| """Return whether we know about resource OID""" |
| assert _assert_valid_oid(oid) |
| return oid in self.__resources |
| |
| def add_subscriber(self, oid, connection): |
| """Notify CONNECTION when OID disappears |
| |
| If OID corresponds to a resource or a connection, set up the |
| subscription link. Otherwise, return false. |
| |
| Subscription additions are idempotent on success. |
| """ |
| assert isinstance(connection, ServerConnection) |
| if oid in self.__refcounts: |
| assert oid in self.__resources |
| if __debug__: |
| logr.debug("linking resource OID %r -> %r", oid, connection) |
| elif oid in self.__connections_by_id: |
| if __debug__: |
| logr.debug("linking connection OID %r -> %r", oid, connection) |
| else: |
| if __debug__: |
| logr.debug("FAILED linking OID %r: not resource or connection", oid) |
| return False |
| self.__subscribers.add(oid, connection) |
| return True |
| |
| def remove_subscriber_unchecked(self, oid, connection): |
| """Remove notification of resource death |
| |
| OID and CONNECTION are the same as for add_subscriber(). |
| """ |
| self.__subscribers.remove(oid, connection) |
| |
| def get_refcount(self, oid): |
| """Return the reference count of resource OID""" |
| return self.__refcounts.get(oid, 0) |
| |
| def get_resource_unchecked(self, oid): |
| """Return the resource object for id OID.""" |
| return self.__resources[oid] |
| |
| def __death_notify(self, oid): |
| subscribers = self.__subscribers.get(oid) |
| if subscribers: |
| for subscriber in list(subscribers): |
| subscriber.notify(MSG_NOTIFICATION_OID_GONE, oid) |
| subscriber.unlink_from_death_internal(oid) |
| assert not subscribers # unlink_from_death mutated |
| assert oid not in self.__subscribers |
| |
| def release_resource_unchecked(self, oid): |
| """Drop a reference to resource OID""" |
| assert oid in self.__resources |
| refcount = self.__refcounts[oid] - 1 |
| if __debug__: |
| logr.debug("dropping global reference to OID %r: new refcount %r", |
| oid, refcount) |
| if refcount: |
| assert refcount > 0 # pylint: disable=compare-to-zero |
| self.__refcounts[oid] = refcount |
| else: |
| if __debug__: |
| for connection in self.__all_connections(): |
| assert not connection.has_reference(oid), \ |
| "connection {} has stray reference to OID {}".format( |
| connection, oid) |
| self.__death_notify(oid) |
| try: |
| # pylint: disable=protected-access |
| self.__resources[oid]._resman_destroy() |
| except: |
| die_due_to_fatal_exception("destroying resource") |
| del self.__refcounts[oid] |
| del self.__resources[oid] |
| |
| def explicit_resource_close(self, oid): |
| """Explicitly clean up a resource. |
| |
    The resource still exists after close, in the sense that its OID
    remains valid; its state, however, may change.
| """ |
| # pylint: disable=protected-access |
| self.__resources[oid]._resman_explicit_close() |
| |
| def find_connection_by_id_unchecked(self, connection_id): |
| """Find a connection with ID |
| |
| Raise KeyError if connection is not found""" |
| return self.__connections_by_id[connection_id] |
| |
| def find_connection_by_id(self, connection_id): |
| """Find a connection with ID. |
| |
| If the connection isn't found, raise ConnectionNotFoundError |
| """ |
| try: |
| return self.__connections_by_id[connection_id] |
| except KeyError: |
| raise ConnectionNotFoundError(connection_id) from None |
| |
| def get_exit_fd(self): |
| """Return a file descriptor usable for monitoring server exit""" |
| if not self.__exit_fd: |
| self.__exit_fd = unix_pipe() |
| return self.__exit_fd[0] |
| |
| def fork(self, fn, process_resmanc_factory): |
| """See LocalClient.fork()""" |
| assert not USE_RESMANC_THREAD |
| assert tls.resman_server |
| # TODO(dancol): use FD inheritance here instead of a command socket |
| command_read, command_write = Channel.make_pair(duplex=False) |
| pid = _do_fork() |
| if not pid: |
| command_write.close() |
| raise _DoForkChildException([process_resmanc_factory, command_read]) |
| process_resmanc_factory.close() |
| child_connection = self.find_connection_by_id_unchecked( |
| process_resmanc_factory.id) |
| command_read.close() |
| command_write.send(fn, FromResmanSendContext.to(child_connection)) |
| command_write.close() |
| return pid |
| |
| def force_close_all_connections(self): |
| """Force all connections closed. |
| |
| Used when a critical connection dies to force all other |
| connections to close. |
| """ |
| for connection in self.__connections_by_id.values(): |
| connection.force_close() |
| |
| def on_connection_dead(self, connection): |
| """Called by connections becoming dead to queue final death""" |
| assert connection.dead |
| assert connection not in self.__dead_connections |
| if __debug__: |
| logr.debug("Adding connection %r to dead connections", connection) |
| self.__dead_connections.append(connection) |
| |
| def __remove_dead_connection(self, connection): |
| if __debug__: |
| logr.debug("removing dead connection fd_numbers=%r id=%r", |
| connection.fd_numbers, connection.id) |
| assert connection.dead |
| connection.force_drop_references() |
| self.__death_notify(connection.id) |
| del self.__connections_by_id[connection.id] |
| for fd_number in connection.fd_numbers: |
| del self.__connections_by_fileno[fd_number] |
| self.__poll.unregister(fd_number) |
| connection.final_cleanup_on_death() |
| |
| def __flush_dead_connections(self): |
| for connection in self.__dead_connections: |
| self.__remove_dead_connection(connection) |
| self.__dead_connections.clear() |
| |
| def __run_main_loop(self): |
| while self.__connections_by_fileno: |
| if __debug__: |
| logr.debug("loop conn=%r", self.__connections_by_fileno) |
| timeout = None |
| for connection in self.__all_connections(): |
| connection.send_pending_notifications() |
| if self.__dead_connections: |
| self.__flush_dead_connections() |
| timeout = 0 # Keep looping until we're stable |
| poll_result = self.__poll.poll(timeout) |
| for fd, _ in poll_result: |
| connection = self.__connections_by_fileno[fd] |
| connection.handle_transaction() |
| |
| @staticmethod |
| def snapshot_api_counts(): |
| """Start tracking API call counts: useful during tests |
| |
| If tracking is not yet enabled, enable it. |
| """ |
| global _resman_api_functions |
| global _resman_api_counts |
| if _resman_api_counts is None: |
| _resman_api_counts = ReferenceCountTracker() |
| def _wrap(fn): |
| @wraps(fn) |
| def _counting_wrapper(*args, **kwargs): |
| _resman_api_counts.addref(fn.__name__) |
| return fn(*args, **kwargs) |
| return _counting_wrapper |
| _resman_api_functions = [_wrap(fn) for fn in _resman_api_functions] |
| old_counts = dict(_resman_api_counts) |
| _resman_api_counts.clear() |
| return old_counts |
| |
| @staticmethod |
| @contextmanager |
| def __set_up_local_client(self_connection): |
| if USE_RESMANC_THREAD: |
| assert not tls.resmanc |
| with tls.set_resmanc(LocalClient(self_connection)): |
| del self_connection |
| yield |
| tls.resmanc.close() |
| else: |
| assert not get_global_process_resmanc.has_value() |
| get_global_process_resmanc.set(LocalClient(self_connection)) |
| del self_connection |
| try: |
| yield |
| finally: |
| get_global_process_resmanc().close() |
| get_global_process_resmanc.reset() |
| |
| def run(self): |
| """Run the ResourceManagerServer thread. |
| |
| Terminates when all connections (including the initial connection, |
| excluding the self connection) go away. |
| """ |
| # pylint: disable=bad-continuation |
| with tls.set_resman_server(threading.current_thread()), \ |
| self.__set_up_local_client( |
| self.__connections_by_id[OID_SELF_CONNECTION]): |
| self.__run_main_loop() |
| assert len(self.__connections_by_id) == 1 |
| self_connection = self.__connections_by_id[OID_SELF_CONNECTION] |
| self_connection.mark_dead() |
| assert self.__dead_connections == [self_connection] |
| self.__flush_dead_connections() |
| assert not self.__resources |
| assert not self.__dead_connections |
| assert not self.__connections_by_id |
| assert not self.__connections_by_fileno |
| if self.__exit_fd: |
| os.write(self.__exit_fd[1].fileno(), b".") |
| |
| def _run_resource_manager(initial_server_channel_byref): |
| if __debug__: |
| logr.debug("resource manager starting") |
| initial_server_channel, = initial_server_channel_byref |
| initial_server_channel_byref.clear() |
| server = ResourceManagerServer() |
| self_connection = server.add_connection(None, "self") |
| assert self_connection.id == OID_SELF_CONNECTION |
| del self_connection |
| initial_connection = server.add_connection( |
| initial_server_channel, "initial") |
| initial_connection.set_critical(True) |
| assert initial_connection.id == OID_INITIAL_CONNECTION_ID |
| del initial_server_channel |
| del initial_connection |
| try: |
| server.run() |
| except _DoForkChildException: |
| server.__dict__.clear() # Break GC cycles |
| raise |
| if __debug__: |
| logr.debug("resource manager exiting") |
| |
| def _do_fork(): |
| """Call os.fork(), running registered cleanups in child""" |
| assert threading.active_count() == 1, "must not fork with threads" |
| pid = os.fork() |
| if not pid: |
| _run_after_forkers() |
| return pid |
| |
| def _daemonize(): |
| if _do_fork(): |
| os._exit(0) # pylint: disable=protected-access |
| os.setsid() |
| os.chdir("/") |
| |
| def _run_fork_child(fork_info): |
| assert not get_global_process_resmanc.has_value() |
| assert not tls.resmanc |
| assert not hasattr(tls, "resman_server") |
| assert not tls.unpickle_context |
| assert not tls.active_calls |
| process_resmanc_factory, command_read = fork_info |
| fork_info.clear() # Drop references |
| del fork_info |
| get_global_process_resmanc.set(process_resmanc_factory()) |
| del process_resmanc_factory |
| fn = command_read.recv(ResourceRecvContext) |
| command_read.close() |
| del command_read |
| return fn() |
| |
| def _run_resource_manager_thread(initial_server_channel_byref): |
| try: |
| _run_resource_manager(initial_server_channel_byref) |
| except: |
| die_due_to_fatal_exception("running resource manager") |
| |
| def _openpty(): |
| master, slave = os.openpty() |
| return FdHolder.steal(master), FdHolder.steal(slave) |
| |
| def _run_resource_manager_process(initial_server_channel_byref): |
| log_verbose = int(os.getenv("MODERNMP_RESMAN_SERVER_VERBOSE", "0")) >= 1 |
| logging.basicConfig( |
| level=logging.DEBUG if log_verbose else logging.INFO, |
| format="%(process)d/%(threadName)s %(levelname)s:%(name)s:%(message)s") |
| _daemonize() |
| master, slave = _openpty() |
| ioctl(slave, TIOCSCTTY, 0) # TODO(dancol) Fix in resmanc! |
| ChildExitWatcher.install() |
| try: |
| return _run_resource_manager(initial_server_channel_byref) |
| except _DoForkChildException as ex: |
| fork_info = ex.info |
| del master, slave |
| ChildExitWatcher.uninstall() |
| |
| # Fork children end up here. We try to keep the child stack as |
| # shallow as possible. (That's part of the reason we bubble control |
| # flow up to here instead of just doing the normal C thing of |
| # running random stuff below the call to fork(2) and calling _exit() |
| # before anyone tries to return into the ancient pre-fork |
| # stack frames.) |
| return _run_fork_child(fork_info) |
| |
| def start_resource_manager_process(name): |
| """Start a resource manager process. |
| |
| NAME is None or the name of the resource manager client. |
| """ |
| # TODO(dancol): start by fork if we get here early enough |
| client_channel, server_channel = Channel.make_pair() |
| from .spawn import start_subprocess |
| with start_subprocess(partial(_run_resource_manager_process, |
| [server_channel]), name=name) as child: |
| # Child will daemonize immediately, so we shouldn't wait long |
| child.wait() |
| if child.returncode: |
| raise RuntimeError("resource manager process failed to start") |
| server_channel.close() |
| return ClientConnectionFactory( |
| client_channel, |
| OID_INITIAL_CONNECTION_ID, |
| name)() |
| |
| def start_resource_manager_thread(name): |
| """Start the resource manager as a thread. |
| |
| NAME is None or a string name for the resource manager. |
| |
| USE ONLY IN TESTS. |
| """ |
| client_channel, server_channel = Channel.make_pair() |
| server_channel_byref = [server_channel] |
| del server_channel |
| thread = threading.Thread( |
| target=_run_resource_manager, |
| args=(server_channel_byref,), |
| daemon=True) |
| thread.start() |
| factory = ClientConnectionFactory( |
| client_channel, |
| OID_INITIAL_CONNECTION_ID, |
| name) |
| client_connection = factory() |
| return client_connection |
| |
| def _make_resman_api_thunk(api_nr, api_fn, cls): |
| # pylint: disable=protected-access |
| oneway = api_fn.oneway |
| def _thunk(self, *args, **kwargs): |
| return self._transact(api_nr, args, kwargs, oneway) |
| api_name = api_fn.__name__ |
| _thunk.__code__ = rename_code_object(_thunk.__code__, api_name) |
| _thunk.__name__ = api_name |
| _thunk.__qualname__ = cls.__qualname__ + "." + api_name |
| return _thunk |
| |
| def _add_resman_api_thunks(cls): |
| for api_nr, api_fn in enumerate(_resman_api_functions): |
| thunk = _make_resman_api_thunk(api_nr, api_fn, cls) |
| assert not hasattr(cls, api_fn.__name__) |
| setattr(cls, api_fn.__name__, thunk) |
| return cls |
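
# Thunk sketch: after _add_resman_api_thunks runs, every @_resman_api
# method on ServerConnection has a same-named client-side counterpart, so
# a call like resmanc.getpid() simply executes
# resmanc._transact(api_nr_of_getpid, (), {}, False), where
# api_nr_of_getpid is the method's index in _resman_api_functions.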
| |
| class ObjCacheWeakRef(weakref.ref): |
| """Weak reference that tracks a connection and OID""" |
| def __new__(cls, ob, callback, connection, oid): |
| assert isinstance(connection, ClientConnectionBase) |
| assert _assert_valid_oid(oid) |
| self = weakref.ref.__new__(cls, ob, callback) |
| self.connection_wr = weakref.ref(connection) |
| self.oid = oid |
| self.refcount = 1 |
| return self |
| |
| def __init__(self, ob, callback, _connection, _oid): |
| super().__init__(ob, callback) |
| |
| __slots__ = "connection_wr", "oid", "refcount" |
| |
| @_add_resman_api_thunks |
| class ClientConnectionBase(ClosingContextManager, CannotPickle): |
| """Base for classes talking to the resource manager.""" |
| def __init__(self, id_): |
| assert the(int, id_) > 0 # pylint: disable=compare-to-zero |
| self._lock = threading.RLock() # Subclasses take lock |
| self.__id = id_ |
| self.__obj_cache = {} |
| |
| def get_net_refcounts(self, oid): |
| """Like get_refcounts, but subtract out accumulated local references |
| |
| Use only for testing. |
| """ |
| with self._lock: |
| local_rc, global_rc = self.get_refcounts(oid) |
| if oid in self.__obj_cache: |
| ref = self.__obj_cache[oid] |
| assert ref.refcount >= local_rc |
| assert ref.refcount |
| local_rc -= ref.refcount - 1 |
| return local_rc, global_rc |
| |
| @cached_property |
| def id(self): |
| """Connection ID""" |
| return self.__id |
| |
| def _do_close(self): |
| # Neuter all the weakref callbacks: we're closing the connection |
| # explicitly, which by side effect releases all references |
| # server-side. |
| with self._lock: |
| for wr in tuple(self.__obj_cache.values()): |
| wr.oid = None |
| self.__obj_cache.clear() |
| super()._do_close() |
| |
| def __clear_stale_obj_locked(self, ref): |
| assert is_lock_owned_by_current_thread(self._lock) |
| oid = ref.oid |
| if oid: |
| # This function can be invoked simultaneously by the weakref |
| # callback and by __addref_cached_obj_locked, so make sure we |
| # actually remove from __obj_cache only once. |
| ref.oid = None |
| del self.__obj_cache[oid] |
| self.release_resource(oid, ref.refcount) |
| |
| @staticmethod |
| def __on_cached_object_collected(ref): |
| # pylint: disable=protected-access |
| try: |
| connection_wr = ref.connection_wr |
| if connection_wr: |
| connection = connection_wr() |
| if connection: |
| with connection._lock: |
| connection.__clear_stale_obj_locked(ref) |
| except: |
| die_due_to_fatal_exception("destroying orphaned object") |
| |
| def __addref_cached_obj_locked(self, oid): |
| ref = self.__obj_cache.get(oid) |
| if ref is None: |
| obj = None |
| else: |
| obj = ref() |
| if obj is None: |
| assert ref.connection_wr() is self |
| self.__clear_stale_obj_locked(ref) |
| else: |
| assert ref.refcount >= 1 |
| ref.refcount += 1 |
| assert obj is None or (isinstance(obj, SharedObject) and obj.oid == oid) |
| return obj |
| |
| def internal_cache_newborn(self, obj_type, obj_dict): |
| """Called during unpickling of freshly-shipped resources. |
| |
| Build a new instance of OBJ_TYPE from OBJ_DICT, but without going |
| through the object's constructor. Cache the new object locally by |
| its OID and return the new object. |
| """ |
| with self._lock: |
| assert issubclass(obj_type, SharedObject) |
| obj = object.__new__(obj_type) |
| obj.__dict__.update(obj_dict) |
| oid = obj.oid |
| # Sometimes, we can be shipped a copy of an object we already |
| # have. In this case, absorb the new reference, discard the |
| # object we were just shipped, and return the canonical copy of |
| # the object for this process. |
| existing_obj = self.__addref_cached_obj_locked(obj.oid) |
| if existing_obj: |
| return existing_obj |
| assert obj.oid not in self.__obj_cache |
| callback = ClientConnectionBase.__on_cached_object_collected |
| self.__obj_cache[oid] = ObjCacheWeakRef(obj, callback, self, oid) |
| # pylint: disable=protected-access |
| obj._resman_after_pull_hook() |
| return obj |
| |
| def load_resource(self, oid): |
| """Resolve an OID we receive over a socket connection. |
| |
| Called by the pickle machinery during deserialization. |
| This function sinks a reference. |
| |
| Return the SharedObject corresponding to OID. |
| """ |
| assert self is get_process_resmanc(), \ |
| "only the process resmanc can own resources" |
| with self._lock: |
| obj = self.__addref_cached_obj_locked(oid) |
| if not obj: |
| # We don't know about this object, so we need to get a copy. |
| # adopt_resource pickles the object in such a way that |
| # unpickling it while reading adopt_resource's return value |
| # ends up calling cache_newborn, which inserts the object into |
| # the object cache. We just check that it's done its job. |
| obj = self.adopt_resource(oid) |
| assert obj.oid == oid |
| assert self.__obj_cache[oid].refcount >= 1 |
| return obj |
| |
  def _transact(self, api_nr, args, kwargs, oneway):
| raise NotImplementedError("abstract") |
| |
| def sync_reference_counts(self): |
| """Flush any pending local reference count changes to the server""" |
| # getpid requires a round trip, so it flushes by side effect |
| self.getpid() |
| |
| class DeathCallbackKey(object): |
| """Token used to remember how to un-add death callback""" |
| def __init__(self, oid, callback): |
| assert _assert_valid_oid(oid) |
| assert callable(callback) |
| self.oid = oid |
| self.callback = callback |
| __slots__ = ["oid", "callback", "__weakref__"] |
| |
| class ClientConnection(ClientConnectionBase): |
| """Socket connection to resource manager server.""" |
| def __init__(self, client_channel, id_, name): |
| ClientConnectionBase.__init__(self, id_) |
| self.__client_channel = the(Channel, client_channel) |
| self.__name = name |
| self.fd_numbers = client_channel.fd_numbers |
| self.__death_callbacks = MultiDict() |
| self.__notification_thread = None |
| self.__notification_semaphore = None |
| self.__notifications = None |
| |
| @property |
| def name(self): |
| """Name of this ClientConnection""" |
| return self.__name |
| |
| def _do_close(self): |
| # Take the lock so that nobody can observe __client_channel being |
| # closed before we clear all our weak reference callbacks. |
| with self._lock: |
| if self.__notification_thread: |
| self.__enqueue_notification((MSG_NOTIFICATION_QUIT_THREAD, None)) |
| # Join with lock unheld because notification thread occasionally |
| # wants to take the lock. |
| if self.__notification_thread: |
| self.__notification_thread.join() |
| with self._lock: |
| if self.__death_callbacks: |
| logc.warning("ignoring pending death callbacks") |
| self.__death_callbacks.clear() |
| self.__client_channel.close() |
| self.__client_channel = None |
| super()._do_close() |
| |
| def __repr__(self): |
| return "ClientConnection(s={!r}, id={!r}, name={!r})".format( |
| self.__client_channel, |
| self.id, |
| self.name) |
| |
| def _transact(self, api_nr, args, kwargs, oneway): |
| try: |
| if __debug__: |
| logc_tx.debug("starting transaction") |
| with self._lock: |
| msg = (MSG_COMMAND, api_nr, args, kwargs) |
| self.__client_channel.send(msg, ToResmanSendContext) |
| if __debug__: |
| logc_tx.debug("transaction sent request %r", msg) |
| if oneway: |
| return None |
| del msg, args, kwargs |
| msg = self.__recv_reply_locked() |
| if __debug__: |
| logc_tx.debug("transaction recv msg=%r", msg) |
| code, payload = msg |
| assert code in (MSG_REPLY_SUCCESS, MSG_REPLY_ERROR) |
| except: |
| die_due_to_fatal_exception("talking to resource manager") |
| |
| if code == MSG_REPLY_ERROR: |
| if __debug__: |
| logc_tx.debug("transaction error %r", payload) |
| reraise_exc_info(payload) |
| assert False, "never reached" |
| if __debug__: |
| logc_tx.debug("transaction result %r", payload) |
| return payload |
| |
| def __enqueue_notification(self, msg): |
| assert msg[0] >= MSG_NOTIFICATION_MIN, \ |
| "not a notification message: {!r}".format(msg) |
| self.__notifications.append(msg) # Atomic |
| self.__notification_semaphore.release() |
| |
| def __recv_reply_locked(self, *, block=True): |
| try: |
| while True: |
| try: |
| msg = self.__client_channel.recv(ResourceRecvContext, block=block) |
| except BlockingIOError: |
| msg = None |
| else: |
| if msg[0] >= MSG_NOTIFICATION_MIN: |
| # Oops. We got a notification instead of the reply we |
| # wanted. Let the notification thread handle the |
| # notification. We shouldn't be getting notifications |
| # unless we asked for them, so the thread and its data |
| # structures should be up and running now. |
| self.__enqueue_notification(msg) |
| continue # Try getting the reply we actually want |
| return msg |
| except: |
| die_due_to_fatal_exception("talking to resource manager") |
| |
| def __recv_notifications_locked(self): |
| while True: |
| try: |
| msg = self.__client_channel.recv(block=False) |
| except BlockingIOError: |
| return |
| self.__enqueue_notification(msg) |
| |
| def __extract_callbacks_for_dead_oids_locked(self, dead_oids): |
    # The caller must hold the lock; keeping acquisition outside this
    # function ensures extracted references can't leak outside the
    # locked region.
| callbacks = [] |
| for dead_oid in dead_oids: |
| for key in self.__death_callbacks.pop(dead_oid, ()): |
| callbacks.append(key.callback) |
| return callbacks |
| |
| def __dispatch_notifications(self, notifications): |
| dead_oids = [] |
| for code, payload in notifications: |
| if code != MSG_NOTIFICATION_OID_GONE: |
| if code == MSG_NOTIFICATION_QUIT_THREAD: |
| assert self.closed |
| return |
| logc.warning("unknown notification code %r", code) |
| continue |
| dead_oids.extend(payload) |
| assert len(set(dead_oids)) == len(dead_oids) |
| with self._lock: |
| callbacks = self.__extract_callbacks_for_dead_oids_locked(dead_oids) |
| callbacks.reverse() |
| while callbacks: |
| callbacks.pop()() |
| self.__client_channel.send((MSG_ACK_NOTIFICATION, len(notifications))) |
| |
| def __run_notification_thread(self): |
| """Run the actual notification listener thread""" |
| try: |
| sem = self.__notification_semaphore |
| notifications = self.__notifications |
| poll = select.poll() |
| for fd_number in self.__client_channel.fd_numbers: |
| poll.register(fd_number, select.POLLIN) |
| poll.register(sem.fileno()) |
| while True: |
| ready_fd = [fd for fd, _ in poll.poll() if fd != sem.fileno()] |
| if ready_fd: |
| with self._lock: |
| self.__recv_notifications_locked() |
| nr_ready = sem.try_acquire_all() |
| if nr_ready: |
| self.__dispatch_notifications(notifications[:nr_ready]) |
| del notifications[:nr_ready] # Atomic |
| if self.closed: |
| break |
| except: |
| die_due_to_fatal_exception("running notification thread") |
| |
| def __ensure_notification_thread_locked(self): |
| """Make sure that the notification-receiver thread is running""" |
| if not self.__notification_thread: |
| self.__notification_semaphore = WaitableSemaphore() |
| self.__notifications = [] |
| self.__notification_thread = threading.Thread( |
| name=self.name + "/notif", |
| daemon=True, |
| target=self.__run_notification_thread) |
| self.__notification_thread.start() |
| |
| def add_death_callback(self, oid, callback): |
| """Add a callback CALLBACK to be made when object OID dies |
| |
| CALLBACK is called with no arguments and with no locks held from a |
| dedicated thread. |
| |
| If CALLBACK is given for a particular OID multiple times, it's |
| called that many times. |
| |
| Return an opaque handle that can be passed to |
| remove_death_callback. |
| """ |
| key = DeathCallbackKey(oid, callback) |
| with self._lock: |
| self.__ensure_notification_thread_locked() |
| if self.__death_callbacks.add(oid, key): |
| self.link_to_death_internal(oid) |
| return key |
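
  # Illustrative usage (hypothetical callback; a sketch, not part of
  # this module's API surface):
  #
  #   key = resmanc.add_death_callback(obj.oid,
  #                                    lambda: logc.info("object died"))
  #   ...
  #   resmanc.remove_death_callback(key)  # No-op if already fired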
| |
| def remove_death_callback(self, key): |
| """Remove a callback added with add_death_callback. |
| |
| Removing a callback that's not present is a no-op. |
| """ |
| assert isinstance(key, DeathCallbackKey) |
| with self._lock: |
| try: |
| last = self.__death_callbacks.remove(key.oid, key) |
| except KeyError: |
| return |
| if last: |
| self.unlink_from_death_internal(key.oid) |
| |
| class ClientConnectionFactory(SafeClosingObject): |
| """Factory for making a ClientConnection |
| |
| A factory is only good for creating a single connection. You can |
| duplicate a factory to create many client objects for the same |
| connection, which we use primarily for anycast. |
| |
| Unlike ClientConnection, ClientConnectionFactory can be sent between |
| processes. Useful for bootstrapping a new child process. |
| """ |
| def __init__(self, client_channel, id_, name=None): |
| assert the(int, id_) > 0 # pylint: disable=compare-to-zero |
| assert isinstance(client_channel, Channel) |
| assert isinstance(name, (type(None), str)) |
| self.__client_channel = client_channel |
| self.__id = id_ |
| self.__name = name |
| |
| @property |
| def id(self): |
| """ID of the connection we'll build""" |
| return self.__id |
| |
| def __call__(self, subclass=ClientConnection, **kwargs): |
| assert issubclass(subclass, ClientConnection) |
| resmanc = subclass(self.__client_channel, self.__id, self.__name, **kwargs) |
| self.__client_channel = None |
| return resmanc |
| |
| def _do_close(self): |
| if self.__client_channel: |
| self.__client_channel.close() |
| super()._do_close() |
| |
| def __repr__(self): |
| return "ClientConnectionFactory(s={!r},id={!r})".format( |
| self.__client_channel, |
| self.__id) |
| |
| class LocalClient(ClientConnectionBase): |
| """Special resource manager client used in resource manager process.""" |
| def __init__(self, server_connection): |
| assert isinstance(server_connection, ServerConnection) |
| assert server_connection.id == OID_SELF_CONNECTION |
| ClientConnectionBase.__init__(self, server_connection.id) |
| self.__server_connection = server_connection |
| |
| @property |
| def server_connection(self): |
| """The server connection corresponding to this client""" |
| return self.__server_connection |
| |
| def __local_copy(self, obj): |
| # We pickle and immediately unpickle so that we correctly |
| # translate local and remote objects, cache references, and do |
| # other work that happens during pickling and |
| # unpickling. Hopefully, we see only small messages here and the |
| # process isn't _that_ inefficient. If the resource manager is |
| # sending 20MB numpy arrays to itself, something is |
| # probably wrong. |
| |
| # TODO(dancol): check for simple cases where a round-trip through |
| # pickle won't change anything and avoid the work. Also somehow |
| # avoid the file descriptor dups. |
| pickle_data, pc = fancy_pickle(obj, ToResmanSendContext) |
| pc.fdhs = [fdh.dup() for fdh in pc.fdhs] |
| return fancy_unpickle(pickle_data, |
| pc.fdhs, |
| ResmanClientRecvContext.from_(self.server_connection)) |
| |
| def _transact(self, api_nr, args, kwargs, _oneway): |
| assert get_process_resmanc() is self |
| assert threading.current_thread() is tls.resman_server |
| try: |
| msg = self.__local_copy((MSG_COMMAND, api_nr, args, kwargs)) |
| del args, kwargs |
| # pylint: disable=protected-access |
| reply = self.__local_copy(self.__server_connection._invoke_api(msg)) |
| if not reply: |
| return None |
| del msg |
| except: |
| die_due_to_fatal_exception("internal resource manager operation") |
| code, payload = reply |
| assert code in (MSG_REPLY_SUCCESS, MSG_REPLY_ERROR) |
| if code == MSG_REPLY_ERROR: |
| try: |
| reraise_exc_info(payload) |
| finally: |
| del code, payload # Break reference cycle |
| return payload |
| |
| def fork(self, fn, process_resmanc_factory): |
| """Fork and call FN in the child. |
| |
| PROCESS_RESMANC_FACTORY makes the process resman client in the |
| child. It's a separate parameter so callers can construct it in |
| advance, wire up an apartment, and then fork. |
| |
| In the child, the brains are sucked out of the current resource |
| manager server and control returned to toplevel before execution |
| of the fork child function proceeds. |
| |
| Return the PID of the child. |
| """ |
| return self.__server_connection.fork(fn, process_resmanc_factory) |
| |
| class SharedObjectMeta(type): |
| """Metaclass that inhibits __init__ when an object is known""" |
| # __init__ runs too late to change the __init__ slot |
| def __new__(mcs, name, bases, dict_): |
| old_init = dict_.get("__init__") |
| if old_init: |
| # pylint: disable=missing-docstring |
| @wraps(old_init) |
| def new_init(self, *args, **kwargs): |
| if not self.oid: |
| old_init(self, *args, **kwargs) |
| dict_["__init__"] = new_init |
| return type.__new__(mcs, name, bases, dict_) |
| |
| class SharedObject(object, metaclass=SharedObjectMeta): |
| """Magic base class that makes a class shareable. |
| |
  Just inherit from this class, and sharing objects of any subtype
  should Just Work. Specifically, plain construction of a subtype
  creates a resource via the process resource manager connection.
| |
| Subclass __init__ will be called only in the server process. |
| Calling SharedObject.__init__ from subclass constructors is |
| not necessary. |
| |
| Subclasses should not override __new__; object.__new__ will be used |
| to construct instances of the subtype in various circumstances. |
| """ |
| |
| # N.B. This class is either trivial or highly magical, depending on |
| # your perspective. It basically hijacks __new__ to turn object |
| # instantiation into 1) transparent instantiation on the resource |
| # manager server, followed by 2) a retrieval and de-pickling of the |
| # resulting value. |
| # |
  # We need to use a metaclass to wrap __init__: if we didn't, the
  # runtime would re-run it after our fake __new__, even client-side,
  # resulting in double initialization.
| # |
| # See test_object_init_run_once. |
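
  # Illustrative sketch with a hypothetical subclass (not part of this
  # module):
  #
  #   class Counter(SharedObject):
  #     def __init__(self, start=0):
  #       self.value = start        # Runs only in the resman process
  #
  #   counter = Counter(5)          # Allocated via make_resource()
  #   assert counter.oid            # Registered with the resource manager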
| |
| oid = None |
| """Persistent cross-process ID for this object""" |
| |
| def __new__(cls, *args, **kwargs): |
| if not tls.unpickle_context: |
| return get_process_resmanc().make_resource(cls, args, kwargs) |
| return object.__new__(cls) |
| |
| def __reduce__(self): |
| assert isinstance(tls.pickle_context, AbstractResourceReducer) |
| return tls.pickle_context.reduce_resource(self) |
| |
| def _resman_destroy(self): |
| """Function called for resource cleanup in resman""" |
| # Override me |
| |
| def _resman_explicit_close(self): |
| """Function called in response to explicit destroy requence""" |
| # Override me |
| |
| def _resman_after_pull_hook(self): |
| """Called just after a client obtains an object""" |
| |
| # It's kind of hacky to add process-management functions to the base |
| # interface, but otherwise we'd have to teach shm.py about |
| # process.py, creating a layering cycle. |
| |
| def _resman_get_process_status(self): |
| """Return the exit status for an object""" |
| raise NotImplementedError("not a process") |
| |
| _resman_send_eager = False |
| """Send object as copy, not reference |
| |
  Sending object contents makes each send more expensive, but frees
  the receiver from having to call adopt_resource if it doesn't have
  the object already. Setting _resman_send_eager to True is useful for
| small objects that don't own file descriptors. |
| """ |
| |
| def cache_object_until_death(obj, oid): |
| """Cache SharedObject OBJ until OID dies""" |
| assert isinstance(obj, SharedObject) |
| assert _assert_valid_oid(oid) |
| assert obj.oid != oid |
| resmanc = get_process_resmanc() |
| if not isinstance(resmanc, LocalClient): |
| res = [obj] |
| def _cache_reclaim_callback(): |
| res.clear() |
| resmanc.add_death_callback(oid, _cache_reclaim_callback) |