blob: a15d284630044b9ea6b024e6b387c8b186196522 [file] [log] [blame]
# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A slower, shittier, and buggier version of Binder"""
# pylint: disable=bare-except
import itertools
import logging
import os
import select
import threading
import weakref
from contextlib import contextmanager
from functools import wraps, partial
from multiprocessing.util import _run_after_forkers
from socket import SHUT_RD
from fcntl import ioctl
from termios import TIOCSCTTY
from .reduction import (
PickleContext,
UnpickleContext,
fancy_pickle,
fancy_unpickle,
)
from .channel import (
Channel,
MessageSendContext,
PeerDiedException,
fdhs_to_ancdata,
)
from .waitsem import WaitableSemaphore
from .util import (
CannotPickle,
ChildExitWatcher,
ClosingContextManager,
FdHolder,
MultiDict,
NoneType,
ReferenceCountTracker,
SafeClosingObject,
cached_property,
capture_exc_info,
die_due_to_fatal_exception,
is_lock_owned_by_current_thread,
once,
rename_code_object,
reraise_exc_info,
the,
tls,
unix_pipe,
)
# Module loggers: two for the client side and one for the resource
# manager server.  Their levels are set by _configure_resman_logging().
logc = logging.getLogger(__name__ + ".client") # pylint: disable=invalid-name
logc_tx = logging.getLogger(__name__ + ".client.tx") # pylint: disable=invalid-name
logr = logging.getLogger(__name__ + ".resman") # pylint: disable=invalid-name
# Maximum number of notifications we batch into a single notification
# packet. Should be small enough that the resulting pickled message
# still fits into a MAX_MSG_SIZE-byte datagram.
MAX_NOTIFICATION_BATCH = 32
# Maximum number of notification packets we allow to sit in a resmanc
# socket buffer before we require explicit acknowledgment.
# TODO(dancol): use a semaphore instead of explicit windowing?
MAX_NOTIFICATION_UNACKED = 2
# Message type tags.  Each is the first element of a message tuple:
# replies are (MSG_REPLY_*, payload), commands are
# (MSG_COMMAND, api_nr, args, kwargs), acknowledgments are
# (MSG_ACK_NOTIFICATION, count) and notifications are (code, [oid, ...]).
MSG_REPLY_SUCCESS = 0
MSG_REPLY_ERROR = 1
MSG_COMMAND = 2
MSG_ACK_NOTIFICATION = 3
# Notification codes start here so they never collide with the tags above.
MSG_NOTIFICATION_MIN = 4
MSG_NOTIFICATION_OID_GONE = MSG_NOTIFICATION_MIN + 0
MSG_NOTIFICATION_QUIT_THREAD = MSG_NOTIFICATION_MIN + 1
# Well-known object IDs.  The server hands out OIDs from a counter that
# starts at 1, and the first two connections created at startup are the
# self connection and the initial client connection, in that order.
OID_INVALID = 0 # So bool evaluation of UIDs doesn't get confused
OID_SELF_CONNECTION = 1 # Resource manager's connection to itself
OID_INITIAL_CONNECTION_ID = 2 # Initial resource manager client
# Useful hack to run everything in a single process for better
# logging, debugging, test startup performance, or something.
# Normally prefer the most robust multi-process approach.
USE_RESMANC_THREAD = os.getenv("MODERNMP_USE_RESMANC_THREAD") == "1"
def _configure_resman_logging():
    """Set the module loggers' levels from MODERNMP_SHM_VERBOSE.

    The variable is read as an integer verbosity: 0 (the default)
    leaves every logger at INFO, anything above 0 switches the resman
    and client loggers to DEBUG, and anything above 1 additionally
    switches the client transaction logger to DEBUG.

    An empty or non-numeric value is treated as 0 and reported with a
    warning instead of raising ValueError, because this function runs
    at import time and a malformed debugging knob should not make the
    whole module unimportable.
    """
    raw_value = os.getenv("MODERNMP_SHM_VERBOSE", "0")
    try:
        # An empty string (e.g. "MODERNMP_SHM_VERBOSE= cmd") means "off"
        shm_verbose = int(raw_value or "0")
    except ValueError:
        logr.warning("ignoring invalid MODERNMP_SHM_VERBOSE=%r", raw_value)
        shm_verbose = 0
    logr.setLevel(logging.INFO)
    logc.setLevel(logging.INFO)
    logc_tx.setLevel(logging.INFO)
    if shm_verbose > 0:  # pylint: disable=compare-to-zero
        logr.setLevel(logging.DEBUG)
        logc.setLevel(logging.DEBUG)
    if shm_verbose > 1:
        logc_tx.setLevel(logging.DEBUG)


_configure_resman_logging()
# Registry of resman API entry points.  The _resman_api() decorator
# appends each decorated function here, and ServerConnection._invoke_api
# looks the function up by its index in this list.
_resman_api_functions = [] # pylint: disable=invalid-name
# Per-API call counters; stays None until
# ResourceManagerServer.snapshot_api_counts() turns tracking on.
_resman_api_counts = None
class ResourceManagerError(Exception):
    """Root of the errors a resource manager is expected to raise.

    Exceptions deriving from this class are sent back to the calling
    client.  Any other exception instantly kills the resource manager
    process and, indirectly through its death, every client.
    """
class ConnectionNotFoundError(ResourceManagerError):
    """Raised when no connection exists for a requested connection ID."""
class ResourceConstructionFailedError(ResourceManagerError):
    """Raised when constructing a resource object fails.

    The exception that caused the failure is chained as __cause__.
    """
class SendMsgError(ResourceManagerError):
    """Raised when a sendmsg performed on a client's behalf fails."""
class _DoForkChildException(BaseException):
"""Special exception to break out of event loop in fork child"""
def __init__(self, info):
super().__init__("fork child")
self.info = info
def _assert_valid_oid(oid):
    """Assert that OID is a positive int; return True.

    Returning True lets callers write ``assert _assert_valid_oid(x)``
    so the whole check disappears when assertions are disabled.
    """
    checked_oid = the(int, oid)
    assert checked_oid > 0  # pylint: disable=compare-to-zero
    return True
@once()
def get_global_process_resmanc():
    """Return the process-global resource manager client.

    The resource manager is started as a thread when
    USE_RESMANC_THREAD is set and as a separate process otherwise.
    """
    name = "resman/{}".format(os.getpid())
    if USE_RESMANC_THREAD:
        starter = start_resource_manager_thread
    else:
        starter = start_resource_manager_process
    return starter(name)
@once()
def get_temp_name_sequence():
    """Return an endless iterator of random name fragments.

    Borrows tempfile's private name generator rather than inventing
    another one.
    """
    # pylint: disable=protected-access
    import tempfile
    return tempfile._RandomNameSequence()
if not USE_RESMANC_THREAD:
    # Not using threaded mode, so can always just use the global, even
    # in the resman server.
    get_process_resmanc = get_global_process_resmanc
else:
    # Threaded mode needs a TLS lookup so that the resmanc thread can
    # use a different client instance when talking to itself.
    def get_process_resmanc():
        """Get the main process connection to the resource manager

        Despite the name, the result may be thread-local.  A
        thread-local resource manager client is useful primarily when
        the resource manager runs as a thread of the current process,
        which is what the resource manager tests do.
        """
        thread_client = tls.resmanc
        if thread_client:
            return thread_client
        return get_global_process_resmanc()
def get_oid(obj):
    """Return the OID of OBJ, which must be a SharedObject."""
    # TODO(dancol): get rid of this function
    assert isinstance(obj, SharedObject)
    oid = obj.oid
    return oid
def _resman_api(*, oneway=False):
    """Build a decorator that registers a function as a resman API.

    The decorated function is tagged with a ``oneway`` attribute
    (one-way calls produce no reply), appended to the API registry,
    and returned unchanged.
    """
    def _register(api_fn):
        api_fn.oneway = oneway
        _resman_api_functions.append(api_fn)
        return api_fn
    return _register
def _do_cache_newborn(obj_type, obj_dict):
    """Unpickle callback: hand a by-value resource to the local client.

    This is the callable that _RegisterNewbornThunk pickles to.
    """
    resmanc = get_process_resmanc()
    return resmanc.internal_cache_newborn(obj_type, obj_dict)
class _RegisterNewbornThunk(object):
    """Wrapper that makes a resource pickle by value, not by reference.

    Functions such as ServerConnection.adopt_resource return a
    resource value to the client, and that return value is pickled
    with the rest of the reply in ServerConnection.__send.  A known
    object is normally pickled as a reference, which the client
    resolves by asking the resource manager for the real payload.
    Doing that for the return value of adopt_resource itself would
    recurse forever.

    ServerConnection functions that return resources "by value"
    therefore return an instance of this class.  It pickles to a
    callback that makes the recipient register the object in its own
    process resource manager client, and it keeps the wrapped object
    from being pickled as a pointer, which prevents the infinite
    recursion during unpickling.
    """

    __slots__ = ["obj"]

    def __init__(self, obj):
        assert isinstance(obj, SharedObject)
        assert obj.oid
        self.obj = obj

    def __reduce__(self):
        wrapped = self.obj
        return _do_cache_newborn, (type(wrapped), wrapped.__dict__)
class ServerConnection(CannotPickle):
"""A ResourceManagerServer-side connection to a client

Tracks what one client holds: its per-connection reference counts,
the OIDs whose death it subscribed to, and the notifications still
waiting to be written to its channel.
"""
# The name-mangled class attributes below serve as defaults; methods
# assign instance attributes of the same name when the state changes.
__dead = False
"""Whether this connnection has failed and is being cleaned up"""
__nr_notifications_unacked = 0
"""Number of notifications the peer hasn't acknowledged"""
__critical = False
"""Kill the whole server if this connection dies"""
def __init__(self, channel, server, name):
    """Set up server-side state for one client connection.

    CHANNEL is the server end of the client's Channel, or None for a
    connection without a socket.  SERVER is the owning
    ResourceManagerServer; it also allocates this connection's ID.
    NAME is a human-readable label.
    """
    self.__channel = the((Channel, NoneType), channel)
    self.__name = the(str, name)
    self.__server = the(ResourceManagerServer, server)
    self.__id = the(int, server.allocate_oid())
    # Reference counts this connection holds, keyed by OID
    self.__references = ReferenceCountTracker()
    # OIDs whose death this connection asked to be told about
    self.__subscribed = set()
    # Queue of (code, [oid, ...]) batches not yet written to the channel
    self.__notifications = []
    if channel:
        self.fd_numbers = channel.fd_numbers
    else:
        self.fd_numbers = ()
@property
def name(self):
    """Human-readable label given to this connection at creation"""
    label = self.__name
    return label
@cached_property
def id(self):
    """The OID the server allocated for this connection"""
    connection_id = self.__id
    return connection_id
@property
def dead(self):
    """True once mark_dead() has started tearing this connection down"""
    is_dead = self.__dead
    return is_dead
def __lookup_conn(self, connection_id):
    """Resolve an optional connection ID to a connection.

    A falsy CONNECTION_ID (such as None) means "this connection".
    Otherwise the server looks the numeric ID up.
    """
    if not connection_id:
        return self
    return self.__server.find_connection_by_id(connection_id)
def fork(self, fn, process_resmanc_factory):
    """Fork the resman process: see LocalClient.fork()

    Simply forwards to the owning server.
    """
    server = self.__server
    return server.fork(fn, process_resmanc_factory)
@_resman_api()
def make_connection_factory(self, name):
    """Create a factory for one new resource manager connection.

    A factory is returned rather than a connection because a factory
    can be sent between processes and a full connection cannot.  The
    factory builds exactly one connection and is useless afterwards.
    """
    assert isinstance(name, str)
    # On error, the client socket disappears and we clean up elegantly
    client_end, server_end = Channel.make_pair()
    new_connection = self.__server.add_connection(server_end, name)
    factory = ClientConnectionFactory(client_end, new_connection.id, name)
    if __debug__:
        logr.debug("new connection factory %r", factory)
    return factory
@_resman_api()
def make_resource(self, resource_type, args, kwargs):
    """Construct a resource inside the resman process.

    RESOURCE_TYPE is a SharedObject subclass; ARGS and KWARGS are
    passed to its __init__.  If construction raises, a
    ResourceConstructionFailedError is raised with the original
    exception as its __cause__.

    The resource's own pickling is used only when the resource
    manager hands an instance to a caller.  Processes sending
    resources to each other send resource pointers, which fetch the
    referenced object from the resource manager during unpickling.

    Return the new resource wrapped for by-value transfer.  The
    calling connection owns a single reference to it.
    """
    assert issubclass(resource_type, SharedObject)
    try:
        new_resource = object.__new__(resource_type)
        resource_type.__init__(new_resource, *args, **kwargs)
    except Exception as ex:
        raise ResourceConstructionFailedError() from ex
    server = self.__server
    oid = server.add_resource(new_resource)
    if __debug__:
        logr.debug("made resource of type %r oid=%r", resource_type, oid)
    assert oid == new_resource.oid
    assert server.get_refcount(oid)
    previous_local_rc = self.__references.addref(oid)
    assert not previous_local_rc
    return _RegisterNewbornThunk(new_resource)
def add_local_refcount(self, oid):
    """Give this connection one more reference to resource OID.

    When this is the connection's first reference to OID, the server's
    global count for OID is bumped as well.
    """
    assert _assert_valid_oid(oid)
    references = self.__references
    if __debug__:
        logr.debug("add_local_refcount owner=%r oid=%r newrc=%r",
                   self.id, oid, references.get(oid, 0) + 1)
    already_held = references.addref(oid)
    if not already_held:
        self.__server.addref_unchecked(oid)
def add_local_refcounts(self, oids):
    """Add one reference for each entry of OIDS, in order"""
    add_one = self.add_local_refcount
    for oid in oids:
        add_one(oid)
def release_local_refcounts(self, oids):
    """Release one reference for each entry of OIDS, in order"""
    release_one = self.release_resource
    for oid in oids:
        release_one(oid)
@_resman_api(oneway=True)
def addref_for_test(self, oid, target_oid=None):
    """Test function: add an artificial reference count to OID

    The reference is added to the connection named by TARGET_OID, or
    to this connection when TARGET_OID is omitted.  The caller must
    already hold a reference to OID.
    """
    assert self.has_reference(oid)
    target = self.__lookup_conn(target_oid)
    target.add_local_refcount(oid)
@_resman_api()
def adopt_resource(self, oid):
    """Fetch an instance of an existing resource.

    From the client's perspective the result is a copy of the
    resource's held object, wired up to call _do_cache_newborn.  The
    caller must already hold a reference to OID; that reference is
    reused and the resource's reference count does not change.
    """
    if __debug__:
        logr.debug("adopting oid=%r", oid)
    assert self.has_reference(oid)
    resource = self.__server.get_resource_unchecked(oid)
    return _RegisterNewbornThunk(resource)
@_resman_api(oneway=True)
def link_to_death_internal(self, oid):
    """Ask to be notified when resource OID disappears.

    The caller need not hold a reference to OID.  If OID is not alive
    right now, a death notification is queued immediately.

    Idempotent while OID is alive: however many times the link is
    requested, only one death notification is delivered.
    """
    assert self.__channel, "LocalClient cannot be notified"
    if oid in self.__subscribed:
        return
    if self.__server.add_subscriber(oid, self):
        self.__subscribed.add(oid)
    else:
        self.notify(MSG_NOTIFICATION_OID_GONE, oid)
@_resman_api(oneway=True)
def unlink_from_death_internal(self, oid):
    """Cancel a death link made by link_to_death_internal.

    Does nothing when no such link exists.
    """
    if oid not in self.__subscribed:
        return
    self.__subscribed.remove(oid)
    self.__server.remove_subscriber_unchecked(oid, self)
@_resman_api(oneway=True)
def release_resource(self, oid, decrement_by=1):
    """Drop DECREMENT_BY of this connection's references to OID.

    The calling connection must hold a reference to the resource.
    When the tracker reports nothing left, the connection's single
    global reference is released too.
    """
    if __debug__:
        logr.debug("removing %r references from OID %r", decrement_by, oid)
    still_held = self.__references.release(oid, decrement_by)
    if not still_held:
        self.__server.release_resource_unchecked(oid)
def has_reference(self, oid):
    """Tell whether this connection holds any reference to OID"""
    references = self.__references
    return oid in references
@_resman_api()
def get_refcounts(self, oid, resmanc=None):
    """Return (LOCAL_REFCOUNT, GLOBAL_REFCOUNT) for OID.

    USE FOR TESTING ONLY.

    RESMANC optionally names another connection to inspect instead of
    this one.  A connection keeps its own count of references to a
    resource, and all of them together account for one increment of
    the resource's global count, so the global count is the number of
    connections referring to the resource.
    """
    # pylint: disable=protected-access
    assert _assert_valid_oid(oid)
    conn = self.__lookup_conn(resmanc)
    local_refcount = conn.__references.get(oid, 0)
    global_refcount = conn.__server.get_refcount(oid)
    return (local_refcount, global_refcount)
@_resman_api()
def getpid(self):  # pylint: disable=no-self-use
    """Return the PID of the process running the resource manager."""
    server_pid = os.getpid()
    return server_pid
@_resman_api()
def get_exit_fd(self):
    """Return a file descriptor usable to wait for server exit

    Works whether the server is a process or a thread.
    """
    server = self.__server
    return server.get_exit_fd()
@_resman_api()
def snapshot_api_counts(self):
    """Return and reset API call counts, enabling tracking if needed

    Intended for tests; forwards to the server.
    """
    server = self.__server
    return server.snapshot_api_counts()
@_resman_api(oneway=True)
def explicit_resource_close(self, oid):
    """Run early cleanup for resource OID

    The calling connection must hold a reference to the resource.
    """
    assert oid in self.__references
    server = self.__server
    server.explicit_resource_close(oid)
@_resman_api()
def make_shared_memory_fd(self):  # pylint: disable=no-self-use
    """Create a shared memory segment and return an FdHolder for it

    Keeps trying fresh random names until one is not already taken.
    """
    # We can make and unlink immediately here, so there's no need for
    # resman-side cleanup via SharedObject.
    from posix_ipc import SharedMemory, ExistentialError
    creation_flags = os.O_CREAT | os.O_EXCL
    segment = None
    while not segment:
        candidate = "/mmp.shm." + next(get_temp_name_sequence())
        try:
            segment = SharedMemory(candidate, creation_flags)
        except ExistentialError:
            # Name already in use: try another one
            pass
    holder = FdHolder.steal(segment.fd)
    assert not os.get_inheritable(segment.fd)
    segment.unlink()
    return holder
@_resman_api()
def cg_sendmsg_nonblock(self, cg_oid, data, fdhs, resources):
    """Post a message and adjust reference counts as one operation

    CG_OID is the OID of a CallGate object, defined in apartment.py;
    the caller must hold a reference to it, and to every OID in
    RESOURCES.  The destination connection is given a reference to
    each of RESOURCES before the send, and those references are taken
    back if the send does not go through.

    Return True on a successful send and False if the send would
    block.  On socket errors raise SendMsgError with the real error as
    __cause__.  If the CallGate's target no longer exists, raise a
    SendMsgError chained to PeerDiedException.
    """
    assert self.has_reference(cg_oid)
    assert all(self.has_reference(oid) for oid in resources)
    server = self.__server
    cg = server.get_resource_unchecked(cg_oid)
    try:
        dst = server.find_connection_by_id_unchecked(cg.process_oid)
    except KeyError:
        raise SendMsgError from PeerDiedException()
    dst.add_local_refcounts(resources)
    try:
        cg.channel.sendmsg_raw(data, fdhs_to_ancdata(fdhs), False)
    except OSError as ex:
        # Either way the message was not delivered: undo the references
        dst.release_local_refcounts(resources)
        if isinstance(ex, BlockingIOError):
            return False
        raise SendMsgError from ex
    return True
@_resman_api()
def get_critical(self):
    """Return this connection's critical flag"""
    is_critical = self.__critical
    return is_critical
@_resman_api(oneway=True)
def set_critical(self, is_critical: bool):
    """Choose whether this connection's death brings down the service

    The flag is consulted in final_cleanup_on_death().
    """
    self.__critical = is_critical
@_resman_api()
def get_process_status(self, oid):
    """Return the exit status of process OID.

    OID must refer to a process object the caller holds a reference
    to.  Mostly useful for testing that process reaping really works.
    """
    assert self.has_reference(oid)
    process = self.__server.get_resource_unchecked(oid)
    # pylint: disable=protected-access
    return process._resman_get_process_status()
def force_close(self):
    """Force a connection closed by shutting down its read side

    No effect on a connection that is already dead or has no channel.
    """
    if self.dead:
        return
    if self.__channel:
        self.__channel.shutdown(SHUT_RD)
def force_drop_references(self):
    """Give up every death link and reference this connection owns.

    Used during connection teardown; the connection must be dead.
    """
    assert self.dead
    # Iterate over snapshots: both loops mutate what they walk
    for linked_oid in tuple(self.__subscribed):
        self.unlink_from_death_internal(linked_oid)
    assert not self.__subscribed  # Mutated above
    server = self.__server
    for owned_oid in list(self.__references):
        del self.__references[owned_oid]
        server.release_resource_unchecked(owned_oid)
def final_cleanup_on_death(self):
    """Last method the server calls before forgetting this object

    Closes the channel, if any, and takes every other connection down
    when this connection was flagged critical.
    """
    channel = self.__channel
    if channel:
        channel.close()
    if self.__critical:
        self.__server.force_close_all_connections()
def mark_dead(self):
    """Begin connection cleanup; does nothing if already dead.

    Pending notifications are discarded and the server is told to
    queue this connection for removal.
    """
    if self.__dead:
        return
    self.__dead = True
    self.__notifications.clear()
    self.__server.on_connection_dead(self)
def __recv(self, *, block=True):
    """Receive one message from the client, or None.

    None means either that nothing could be read without blocking or
    that the peer died; in the latter case the connection is also
    marked dead.
    """
    recv_context = ResmanClientRecvContext.from_(self)
    try:
        return self.__channel.recv(recv_context, block=block)
    except BlockingIOError:
        return None
    except PeerDiedException:
        self.mark_dead()
        return None
def __send(self, obj, block=True):
    """Send OBJ to the client and return the channel's send result.

    If the peer died, mark the connection dead and return False.
    """
    send_context = FromResmanSendContext.to(self)
    try:
        return self.__channel.send(obj, send_context, block=block)
    except PeerDiedException:
        self.mark_dead()
        return False
def notify(self, code, oid):
    """Queue notification CODE about OID and try to deliver it

    Consecutive notifications with the same code are batched into one
    packet of at most MAX_NOTIFICATION_BATCH OIDs.  Nothing is queued
    for a dead connection.
    """
    assert self.__channel
    if self.__dead:
        return
    queue = self.__notifications
    can_extend_last = (queue and
                       queue[-1][0] == code and
                       len(queue[-1][1]) < MAX_NOTIFICATION_BATCH)
    if can_extend_last:
        queue[-1][1].append(oid)
    else:
        queue.append((code, [oid]))
    self.send_pending_notifications()
def send_pending_notifications(self):
    """Write queued notification packets to the socket, if possible

    Stops when the queue is empty, when MAX_NOTIFICATION_UNACKED
    packets are awaiting acknowledgment, or when a non-blocking send
    does not succeed.
    """
    while self.__notifications:
        if self.__nr_notifications_unacked >= MAX_NOTIFICATION_UNACKED:
            break
        if not self.__send(self.__notifications[0], block=False):
            break
        if __debug__:
            logr.debug("sent notification to socket")
        self.__nr_notifications_unacked += 1
        self.__notifications.pop(0)
def _invoke_api(self, msg):
    """Run the API call described by command message MSG.

    MSG is (tag, api_nr, args, kwargs).  Return None for a one-way
    API.  Otherwise return the reply tuple: (MSG_REPLY_SUCCESS, value)
    or, if the call raised a ResourceManagerError,
    (MSG_REPLY_ERROR, captured exception info).  Other exceptions
    propagate to the caller.
    """
    _, api_nr, args, kwargs = msg
    api_fn = _resman_api_functions[api_nr]
    if __debug__:
        logr.debug("API call from client %s to %s",
                   self.id, api_fn.__name__)
    if api_fn.oneway:
        api_fn(self, *args, **kwargs)
        return None
    try:
        result = api_fn(self, *args, **kwargs)
    except ResourceManagerError:
        return (MSG_REPLY_ERROR, capture_exc_info())
    return (MSG_REPLY_SUCCESS, result)
def handle_transaction(self):
    """Process every message currently readable from the client.

    Commands are executed and, unless one-way, answered.
    Acknowledgments shrink the count of unacknowledged notification
    packets.  Any other message raises AssertionError.  Returns
    nothing; a dead peer is handled by marking the connection dead.
    """
    while True:
        msg = self.__recv(block=False)
        if msg is None:
            return
        msg_kind = msg[0]
        if msg_kind == MSG_COMMAND:
            reply = self._invoke_api(msg)
            if reply:
                self.__send(reply)
        elif msg_kind == MSG_ACK_NOTIFICATION:
            _, nr_acked = msg
            assert isinstance(nr_acked, int)
            assert nr_acked <= self.__nr_notifications_unacked
            self.__nr_notifications_unacked -= nr_acked
        else:
            raise AssertionError("unknown message {!r}".format(msg))
def __repr__(self):
    """Debug representation showing name, ID and channel"""
    template = "ServerConnection(name={!r}, id={!r}, channel={!r})"
    return template.format(self.name, self.id, self.__channel)
def _rebuild_resource(oid):
    """Unpickle callback: turn a pickled OID back into an object.

    Delegates to the unpickle context active on this thread, which
    must be a ResourceRecvContext.
    """
    context = tls.unpickle_context
    assert isinstance(context, ResourceRecvContext)
    return context.rebuild_resource(oid)
class ShmSharedMemorySendContextMixin(MessageSendContext):
    """Send-context mixin that gets shared memory from resman"""

    @staticmethod
    def make_shared_memory_fd() -> FdHolder:
        """Ask the process resman client for a new memory FD"""
        # TODO(dancol): use Linux memfd where available
        resmanc = get_process_resmanc()
        return resmanc.make_shared_memory_fd()
class AbstractResourceReducer(PickleContext):
    """Pickle-context interface for reducing resources"""

    def reduce_resource(self, obj):
        """__reduce__ for a SharedObject; subclasses must override"""
        raise NotImplementedError("abstract")
class ResourceSendingMessageContextMixin(AbstractResourceReducer):
    """Send-context mixin that records the OIDs of resources it pickles"""

    def __init__(self):
        super().__init__()
        # OIDs of every resource reduced through this context, in order
        self.resources = []

    def reduce_resource(self, obj):
        """__reduce__ for a SharedObject

        Records the OID, then pickles the object by value when it asks
        for eager sending and as an OID reference otherwise.
        """
        # pylint: disable=protected-access
        self.resources.append(obj.oid)
        if not obj._resman_send_eager:
            return _rebuild_resource, (obj.oid,)
        return _RegisterNewbornThunk(obj).__reduce__()
class ToResmanSendContext(
        ShmSharedMemorySendContextMixin,
        MessageSendContext,
        AbstractResourceReducer,
):
    """Send context for messages addressed to resman itself"""

    def reduce_resource(self, obj):
        """__reduce__ for a SharedObject: always an OID reference"""
        # No need to do anything here: we don't have to tell the server to
        # add a reference when it has all possible references already.
        oid = obj.oid
        return _rebuild_resource, (oid,)
class FromResmanSendContext(
        ShmSharedMemorySendContextMixin,
        MessageSendContext,
        ResourceSendingMessageContextMixin,
):
    """Send context that credits references to the receiving connection

    Usable only from the resman server itself.
    """

    def __init__(self, to):
        super().__init__()
        self.__to = the(ServerConnection, to)

    @classmethod
    def to(cls, to):
        """Return a callable that builds a context addressed to TO"""
        return partial(cls, to)

    def sendmsg(self, pickle_data, block, channel):
        """Send the message, giving the recipient the pickled references

        The references are added before the send and taken back if
        the send raises anything at all.
        """
        recipient = self.__to
        recipient.add_local_refcounts(self.resources)
        try:
            super().sendmsg(pickle_data, block, channel)
        except BaseException:
            recipient.release_local_refcounts(self.resources)
            raise
class ResourceRecvContext(UnpickleContext):
    """Unpickle context used when receiving messages"""

    @staticmethod
    def rebuild_resource(oid):
        """Load the object for a pickled OID via the process client"""
        resmanc = get_process_resmanc()
        return resmanc.load_resource(oid)
class ResmanClientRecvContext(ResourceRecvContext):
    """Unpickle context for messages arriving from a resman client"""

    def __init__(self, *, from_, **kwargs):
        super().__init__(**kwargs)
        self.__from = the(ServerConnection, from_)

    @classmethod
    def from_(cls, from_):
        """Return a callable that builds a context for sender FROM_"""
        return partial(cls, from_=from_)

    def rebuild_resource(self, oid):
        """Load OID after crediting the server's own connection

        The sender must hold a reference to OID.
        """
        # We're the resource server, so there's no point in the client
        # sending a separate message to addref ourselves. Just add
        # virtual references on demand.
        assert self.__from.has_reference(oid)
        own_connection = get_process_resmanc().server_connection
        own_connection.add_local_refcount(oid)
        return super().rebuild_resource(oid)
class ResourceManagerServer(object):
"""Provides resource manager server side

Holds the tables of connections, resources, global reference counts
and death subscribers, and runs the poll loop that services the
client connections.
"""
def __init__(self):
    """Create an empty server with no connections or resources"""
    # Connections, indexed by polled file descriptor and by OID
    self.__connections_by_fileno = {}
    self.__connections_by_id = {}
    # Resource objects and their global reference counts, by OID
    self.__resources = {}
    self.__refcounts = {}
    # OID -> connections to notify when that OID goes away
    self.__subscribers = MultiDict()
    # Connections marked dead and awaiting removal
    self.__dead_connections = []
    self.__poll = select.poll()  # TODO(dancol): use selectors
    # OIDs start at 1; 0 is OID_INVALID
    self.__oid_gen = itertools.count(1)
    self.__randname_sequence = None
    # Created lazily by get_exit_fd()
    self.__exit_fd = None
def allocate_oid(self):
"""Make a new unique ID for connections, resources, etc."""
return next(self.__oid_gen)
def __all_connections(self):
return self.__connections_by_id.values()
def add_connection(self, server_channel, name) -> ServerConnection:
    """Register and return a new connection.

    SERVER_CHANNEL is the server side of a channel pair made with
    Channel.make_pair(), or a falsy value for a connection without a
    socket; only channel-backed connections are polled.
    """
    connection = ServerConnection(server_channel, self, name)
    if server_channel:
        by_fileno = self.__connections_by_fileno
        for fd_number in connection.fd_numbers:
            assert fd_number not in by_fileno
            by_fileno[fd_number] = connection
            self.__poll.register(fd_number, select.POLLIN)
    self.__connections_by_id[connection.id] = connection
    if __debug__:
        logr.debug("added connection fd_numbers=%r id=%r",
                   connection.fd_numbers, connection.id)
    return connection
def add_resource(self, resource):
    """Adopt the new resource RESOURCE and return its new OID.

    RESOURCE may be of any type derived from SharedObject and must
    not have an OID yet.  It starts with a global reference count
    of one.
    """
    assert isinstance(resource, SharedObject)
    assert not resource.oid
    oid = self.allocate_oid()
    resource.oid = oid
    assert oid not in self.__resources
    self.__refcounts[oid] = 1
    self.__resources[oid] = resource
    return oid
def addref_unchecked(self, oid):
    """Add one global reference to resource OID.

    OID must name a resource the server knows about; this is only
    asserted, hence "unchecked".
    """
    assert oid in self.__resources
    # Guarded like every other debug log in this class, so that the
    # arguments are not computed when assertions are compiled out.
    if __debug__:
        logr.debug("adding global refcount oid=%r newrc=%r", oid,
                   self.__refcounts.get(oid, 0) + 1)
    self.__refcounts[oid] += 1
def has_resource(self, oid):
    """Tell whether OID names a resource this server holds"""
    assert _assert_valid_oid(oid)
    known = oid in self.__resources
    return known
def add_subscriber(self, oid, connection):
    """Arrange for CONNECTION to be notified when OID disappears

    Return True after recording the link if OID names a resource or
    a connection.  Otherwise record nothing and return False.

    Subscription additions are idempotent on success.
    """
    assert isinstance(connection, ServerConnection)
    is_resource = oid in self.__refcounts
    if is_resource:
        assert oid in self.__resources
        if __debug__:
            logr.debug("linking resource OID %r -> %r", oid, connection)
    elif oid in self.__connections_by_id:
        if __debug__:
            logr.debug("linking connection OID %r -> %r", oid, connection)
    else:
        if __debug__:
            logr.debug("FAILED linking OID %r: not resource or connection", oid)
        return False
    self.__subscribers.add(oid, connection)
    return True
def remove_subscriber_unchecked(self, oid, connection):
    """Stop notifying CONNECTION about the death of OID

    OID and CONNECTION are the same as for add_subscriber().
    """
    subscribers = self.__subscribers
    subscribers.remove(oid, connection)
def get_refcount(self, oid):
"""Return the reference count of resource OID"""
return self.__refcounts.get(oid, 0)
def get_resource_unchecked(self, oid):
"""Return the resource object for id OID."""
return self.__resources[oid]
def __death_notify(self, oid):
    """Tell every subscriber of OID that it is gone and unlink them"""
    subscribers = self.__subscribers.get(oid)
    if subscribers:
        # Walk a snapshot: unlinking mutates the subscriber collection
        for subscriber in tuple(subscribers):
            subscriber.notify(MSG_NOTIFICATION_OID_GONE, oid)
            subscriber.unlink_from_death_internal(oid)
        assert not subscribers  # unlink_from_death mutated
    assert oid not in self.__subscribers
def release_resource_unchecked(self, oid):
    """Drop one global reference to resource OID

    When the count reaches zero the subscribers are notified, the
    resource is destroyed and its entries are removed.  A failure
    while destroying the resource is fatal.
    """
    assert oid in self.__resources
    new_refcount = self.__refcounts[oid] - 1
    if __debug__:
        logr.debug("dropping global reference to OID %r: new refcount %r",
                   oid, new_refcount)
    if new_refcount:
        assert new_refcount > 0  # pylint: disable=compare-to-zero
        self.__refcounts[oid] = new_refcount
        return
    # Last reference gone: nobody may still claim the resource
    if __debug__:
        for connection in self.__all_connections():
            assert not connection.has_reference(oid), \
                "connection {} has stray reference to OID {}".format(
                    connection, oid)
    self.__death_notify(oid)
    try:
        # pylint: disable=protected-access
        self.__resources[oid]._resman_destroy()
    except BaseException:
        die_due_to_fatal_exception("destroying resource")
    del self.__refcounts[oid]
    del self.__resources[oid]
def explicit_resource_close(self, oid):
    """Explicitly clean up a resource.

    After the close the resource still exists, in the sense that its
    OID remains valid, but its state may have changed.
    """
    resource = self.__resources[oid]
    # pylint: disable=protected-access
    resource._resman_explicit_close()
def find_connection_by_id_unchecked(self, connection_id):
"""Find a connection with ID
Raise KeyError if connection is not found"""
return self.__connections_by_id[connection_id]
def find_connection_by_id(self, connection_id):
    """Return the connection with ID CONNECTION_ID.

    Raise ConnectionNotFoundError, carrying the ID, if there is no
    such connection.
    """
    by_id = self.__connections_by_id
    try:
        return by_id[connection_id]
    except KeyError:
        raise ConnectionNotFoundError(connection_id) from None
def get_exit_fd(self):
"""Return a file descriptor usable for monitoring server exit"""
if not self.__exit_fd:
self.__exit_fd = unix_pipe()
return self.__exit_fd[0]
def fork(self, fn, process_resmanc_factory):
    """See LocalClient.fork()

    In the parent, send FN to the child over a one-way command channel
    and return the child's PID.  In the child, raise
    _DoForkChildException carrying the factory and the read end of
    that channel.
    """
    assert not USE_RESMANC_THREAD
    assert tls.resman_server
    # TODO(dancol): use FD inheritance here instead of a command socket
    command_read, command_write = Channel.make_pair(duplex=False)
    child_pid = _do_fork()
    if not child_pid:
        # Child: keep only the read end and unwind out of the server
        command_write.close()
        raise _DoForkChildException([process_resmanc_factory, command_read])
    # Parent: the factory and the read end now belong to the child
    process_resmanc_factory.close()
    child_connection = self.find_connection_by_id_unchecked(
        process_resmanc_factory.id)
    command_read.close()
    send_context = FromResmanSendContext.to(child_connection)
    command_write.send(fn, send_context)
    command_write.close()
    return child_pid
def force_close_all_connections(self):
    """Force every connection closed.

    Used when a critical connection dies, to make all the other
    connections close as well.
    """
    connections = self.__connections_by_id.values()
    for connection in connections:
        connection.force_close()
def on_connection_dead(self, connection):
    """Queue a newly dead connection for final removal

    Called by connections as they become dead.
    """
    assert connection.dead
    dead_queue = self.__dead_connections
    assert connection not in dead_queue
    if __debug__:
        logr.debug("Adding connection %r to dead connections", connection)
    dead_queue.append(connection)
def __remove_dead_connection(self, connection):
    """Forget a dead connection: drop what it owns, notify, unregister"""
    if __debug__:
        logr.debug("removing dead connection fd_numbers=%r id=%r",
                   connection.fd_numbers, connection.id)
    assert connection.dead
    connection.force_drop_references()
    connection_id = connection.id
    self.__death_notify(connection_id)
    del self.__connections_by_id[connection_id]
    for fd_number in connection.fd_numbers:
        del self.__connections_by_fileno[fd_number]
        self.__poll.unregister(fd_number)
    connection.final_cleanup_on_death()
def __flush_dead_connections(self):
    """Remove every queued dead connection, then empty the queue"""
    dead_queue = self.__dead_connections
    # The live list is iterated on purpose: entries appended while the
    # loop runs are picked up by the same loop.
    for connection in dead_queue:
        self.__remove_dead_connection(connection)
    dead_queue.clear()
def __run_main_loop(self):
    """Serve clients until no channel-backed connection remains

    Each round flushes pending notifications, removes dead
    connections, polls, and handles whatever became readable.
    """
    while self.__connections_by_fileno:
        if __debug__:
            logr.debug("loop conn=%r", self.__connections_by_fileno)
        for connection in self.__all_connections():
            connection.send_pending_notifications()
        if self.__dead_connections:
            self.__flush_dead_connections()
            timeout = 0  # Keep looping until we're stable
        else:
            timeout = None
        for fd, _events in self.__poll.poll(timeout):
            ready_connection = self.__connections_by_fileno[fd]
            ready_connection.handle_transaction()
@staticmethod
def snapshot_api_counts():
    """Return the API call counts so far and reset them

    Useful during tests.  The first call turns tracking on by
    replacing every registered API function with a counting wrapper.
    """
    global _resman_api_functions
    global _resman_api_counts
    if _resman_api_counts is None:
        _resman_api_counts = ReferenceCountTracker()

        def _wrap(fn):
            @wraps(fn)
            def _counting_wrapper(*args, **kwargs):
                _resman_api_counts.addref(fn.__name__)
                return fn(*args, **kwargs)
            return _counting_wrapper

        wrapped_functions = []
        for api_fn in _resman_api_functions:
            wrapped_functions.append(_wrap(api_fn))
        _resman_api_functions = wrapped_functions
    snapshot = dict(_resman_api_counts)
    _resman_api_counts.clear()
    return snapshot
@staticmethod
@contextmanager
def __set_up_local_client(self_connection):
    """Install a LocalClient for SELF_CONNECTION while the body runs

    In threaded mode the client is installed thread-locally.
    Otherwise it becomes the process-global client, which is closed
    and reset again on the way out.
    """
    if not USE_RESMANC_THREAD:
        assert not get_global_process_resmanc.has_value()
        get_global_process_resmanc.set(LocalClient(self_connection))
        del self_connection
        try:
            yield
        finally:
            get_global_process_resmanc().close()
            get_global_process_resmanc.reset()
    else:
        assert not tls.resmanc
        with tls.set_resmanc(LocalClient(self_connection)):
            del self_connection
            yield
            tls.resmanc.close()
def run(self):
"""Run the ResourceManagerServer thread.
Terminates when all connections (including the initial connection,
excluding the self connection) go away.
"""
# Publish the current thread as the resman server thread and install a
# LocalClient built on the self connection.
# pylint: disable=bad-continuation
with tls.set_resman_server(threading.current_thread()), \
self.__set_up_local_client(
self.__connections_by_id[OID_SELF_CONNECTION]):
# Serve clients until no channel-backed connection is left.
self.__run_main_loop()
# Only the self connection should remain; retire it through the normal
# dead-connection path so that whatever it owns is released.
assert len(self.__connections_by_id) == 1
self_connection = self.__connections_by_id[OID_SELF_CONNECTION]
self_connection.mark_dead()
assert self.__dead_connections == [self_connection]
self.__flush_dead_connections()
# By now every table must be empty.
assert not self.__resources
assert not self.__dead_connections
assert not self.__connections_by_id
assert not self.__connections_by_fileno
# Wake anyone waiting on the descriptor handed out by get_exit_fd().
if self.__exit_fd:
os.write(self.__exit_fd[1].fileno(), b".")
def _run_resource_manager(initial_server_channel_byref):
    """Build a ResourceManagerServer and run it on the calling thread.

    INITIAL_SERVER_CHANNEL_BYREF is a one-element list holding the
    server end of the initial connection's channel.  The list is
    emptied here so the caller's list no longer references the channel.

    A _DoForkChildException raised by the server is propagated after
    the server object's attributes have been cleared.
    """
    if __debug__:
        logr.debug("resource manager starting")
    (server_channel,) = initial_server_channel_byref
    initial_server_channel_byref.clear()

    server = ResourceManagerServer()

    # The self connection is registered first and must receive the
    # reserved self-connection ID.
    self_connection = server.add_connection(None, "self")
    assert self_connection.id == OID_SELF_CONNECTION
    del self_connection

    # The initial connection is marked critical and must receive the
    # reserved initial-connection ID.
    first_connection = server.add_connection(server_channel, "initial")
    del server_channel
    first_connection.set_critical(True)
    assert first_connection.id == OID_INITIAL_CONNECTION_ID
    del first_connection

    try:
        server.run()
    except _DoForkChildException:
        server.__dict__.clear()  # Break GC cycles
        raise
    if __debug__:
        logr.debug("resource manager exiting")
def _do_fork():
    """Fork the process, running multiprocessing's after-fork hooks in the child.

    Must be called while only one thread is active.  Returns the value
    of os.fork(): zero in the child, the child's PID in the parent.
    """
    assert threading.active_count() == 1, "must not fork with threads"
    child_pid = os.fork()
    if child_pid == 0:
        # Only the child runs the registered after-fork cleanups.
        _run_after_forkers()
    return child_pid
def _daemonize():
    """Detach from the parent: fork, exit the parent, start a new session.

    The surviving child becomes a session leader and changes its
    working directory to the filesystem root.
    """
    forked_pid = _do_fork()
    if forked_pid:
        # Parent: leave immediately without running normal exit handling.
        os._exit(0)  # pylint: disable=protected-access
    os.setsid()
    os.chdir("/")
def _run_fork_child(fork_info):
    """Finish setting up a forked child and run its entry function.

    FORK_INFO is a two-element mutable sequence: a factory producing
    the child's process resmanc, and a channel from which the function
    to run is received.  The sequence is cleared so it keeps nothing
    alive.  Returns whatever the received function returns.
    """
    # The child must start with no resource manager state inherited
    # from the server it was forked from.
    assert not get_global_process_resmanc.has_value()
    assert not tls.resmanc
    assert not hasattr(tls, "resman_server")
    assert not tls.unpickle_context
    assert not tls.active_calls

    make_process_resmanc, command_channel = fork_info
    fork_info.clear()  # Drop references
    del fork_info

    get_global_process_resmanc.set(make_process_resmanc())
    del make_process_resmanc

    # The function to run arrives over the command channel, which is
    # not needed afterwards.
    child_main = command_channel.recv(ResourceRecvContext)
    command_channel.close()
    del command_channel
    return child_main()
def _run_resource_manager_thread(initial_server_channel_byref):
    """Run the resource manager, treating any escaping exception as fatal.

    INITIAL_SERVER_CHANNEL_BYREF is passed through unchanged to
    _run_resource_manager.
    """
    try:
        _run_resource_manager(initial_server_channel_byref)
    except:
        # Nothing above us can handle a failure here; report it through
        # the fatal-exception path.
        die_due_to_fatal_exception("running resource manager")
def _openpty():
    """Open a pseudo-terminal pair.

    Returns a (master, slave) tuple in which each raw descriptor from
    os.openpty() has been handed to FdHolder.steal.
    """
    master_fd, slave_fd = os.openpty()
    master_holder = FdHolder.steal(master_fd)
    slave_holder = FdHolder.steal(slave_fd)
    return master_holder, slave_holder
def _run_resource_manager_process(initial_server_channel_byref):
    """Entry point of the dedicated resource manager process.

    Configures logging (DEBUG when MODERNMP_RESMAN_SERVER_VERBOSE is at
    least 1, INFO otherwise), daemonizes, acquires a controlling
    terminal from a fresh pty, installs the child-exit watcher and runs
    the resource manager.

    In the server process this returns when the resource manager
    returns.  In a process forked from the server, control arrives
    here via _DoForkChildException and continues into _run_fork_child.
    """
    verbosity = int(os.getenv("MODERNMP_RESMAN_SERVER_VERBOSE", "0"))
    logging.basicConfig(
        level=logging.INFO if verbosity < 1 else logging.DEBUG,
        format="%(process)d/%(threadName)s %(levelname)s:%(name)s:%(message)s")
    _daemonize()
    master, slave = _openpty()
    ioctl(slave, TIOCSCTTY, 0)  # TODO(dancol) Fix in resmanc!
    ChildExitWatcher.install()
    try:
        return _run_resource_manager(initial_server_channel_byref)
    except _DoForkChildException as fork_request:
        fork_info = fork_request.info
    # Only a fork child gets past this point: undo the server-only setup.
    del master, slave
    ChildExitWatcher.uninstall()
    # Fork children end up here. We try to keep the child stack as
    # shallow as possible. (That's part of the reason we bubble control
    # flow up to here instead of just doing the normal C thing of
    # running random stuff below the call to fork(2) and calling _exit()
    # before anyone tries to return into the ancient pre-fork
    # stack frames.)
    return _run_fork_child(fork_info)
def start_resource_manager_process(name):
    """Start a resource manager process.

    NAME is None or the name of the resource manager client.

    Returns a client connection to the new resource manager, built
    with ClientConnectionFactory for the initial connection ID.

    Raises RuntimeError if the launched subprocess exits with a nonzero
    status.  On any failure to start, both ends of the freshly created
    channel pair are closed before the exception propagates.
    """
    # TODO(dancol): start by fork if we get here early enough
    from .spawn import start_subprocess
    client_channel, server_channel = Channel.make_pair()
    try:
        try:
            with start_subprocess(partial(_run_resource_manager_process,
                                          [server_channel]),
                                  name=name) as child:
                # Child will daemonize immediately, so we shouldn't wait long
                child.wait()
                if child.returncode:
                    raise RuntimeError(
                        "resource manager process failed to start")
        finally:
            # The server end belongs to the resource manager process;
            # this process never needs its copy again, success or not.
            server_channel.close()
    except:
        # Nobody will ever use the client end if startup failed.
        client_channel.close()
        raise
    return ClientConnectionFactory(
        client_channel,
        OID_INITIAL_CONNECTION_ID,
        name)()
def start_resource_manager_thread(name):
    """Start the resource manager as a daemon thread in this process.

    NAME is None or a string name for the resource manager.
    USE ONLY IN TESTS.

    Returns a client connection for the initial connection ID.
    """
    client_channel, server_channel = Channel.make_pair()
    # Pass the server end inside a one-element list and drop our own
    # name for it, so that only the list refers to the channel.
    byref = [server_channel]
    del server_channel
    server_thread = threading.Thread(
        target=_run_resource_manager,
        args=(byref,),
        daemon=True)
    server_thread.start()
    make_connection = ClientConnectionFactory(
        client_channel,
        OID_INITIAL_CONNECTION_ID,
        name)
    return make_connection()
def _make_resman_api_thunk(api_nr, api_fn, cls):
    """Build the client-side method that invokes API number API_NR.

    The returned function forwards its positional and keyword
    arguments to self._transact, along with API_NR and the `oneway`
    attribute read from API_FN.  Its code object, __name__ and
    __qualname__ are renamed so it appears as a method of CLS named
    after API_FN.
    """
    # pylint: disable=protected-access
    is_oneway = api_fn.oneway
    method_name = api_fn.__name__

    def _thunk(self, *args, **kwargs):
        return self._transact(api_nr, args, kwargs, is_oneway)

    _thunk.__code__ = rename_code_object(_thunk.__code__, method_name)
    _thunk.__name__ = method_name
    _thunk.__qualname__ = cls.__qualname__ + "." + method_name
    return _thunk
def _add_resman_api_thunks(cls):
    """Class decorator: attach one thunk method per resource manager API.

    Each function in _resman_api_functions contributes a method named
    after it; the function's index in that list is its API number.
    CLS must not already have an attribute of that name.  Returns CLS.
    """
    for api_nr, api_fn in enumerate(_resman_api_functions):
        method_name = api_fn.__name__
        thunk = _make_resman_api_thunk(api_nr, api_fn, cls)
        assert not hasattr(cls, method_name)
        setattr(cls, method_name, thunk)
    return cls
class ObjCacheWeakRef(weakref.ref):
    """Weak reference that tracks a connection and OID

    Besides the referent, each instance carries a weak reference to
    the owning connection (connection_wr), the OID of the referent
    (oid) and a reference count (refcount) that starts at one.
    """

    __slots__ = "connection_wr", "oid", "refcount"

    def __new__(cls, ob, callback, connection, oid):
        assert isinstance(connection, ClientConnectionBase)
        assert _assert_valid_oid(oid)
        # weakref.ref takes only the referent and the callback; the
        # extra fields are filled in by hand.
        self = super().__new__(cls, ob, callback)
        self.refcount = 1
        self.oid = oid
        self.connection_wr = weakref.ref(connection)
        return self

    def __init__(self, ob, callback, _connection, _oid):
        # The extra arguments were already consumed by __new__.
        super().__init__(ob, callback)
@_add_resman_api_thunks
class ClientConnectionBase(ClosingContextManager, CannotPickle):
    """Base for classes talking to the resource manager.

    The _add_resman_api_thunks decorator gives this class one method
    per resource manager API function; every such method funnels into
    _transact, which subclasses implement.

    The class also maintains a cache mapping OIDs to ObjCacheWeakRef
    entries for the SharedObjects this connection has received,
    together with the number of references each entry has absorbed.
    """

    def __init__(self, id_):
        assert the(int, id_) > 0  # pylint: disable=compare-to-zero
        self._lock = threading.RLock()  # Subclasses take lock
        self.__id = id_
        # OID -> ObjCacheWeakRef for objects known to this connection
        self.__obj_cache = {}

    def get_net_refcounts(self, oid):
        """Like get_refcounts, but subtract out accumulated local references

        Use only for testing.

        Returns a (local, global) pair.  When OID is in the object
        cache, the local count is reduced by the references the cache
        entry has absorbed beyond its first.
        """
        with self._lock:
            local_rc, global_rc = self.get_refcounts(oid)
            if oid in self.__obj_cache:
                ref = self.__obj_cache[oid]
                assert ref.refcount >= local_rc
                assert ref.refcount
                local_rc -= ref.refcount - 1
            return local_rc, global_rc

    @cached_property
    def id(self):
        """Connection ID"""
        return self.__id

    def _do_close(self):
        # Neuter all the weakref callbacks: we're closing the connection
        # explicitly, which by side effect releases all references
        # server-side.
        with self._lock:
            for wr in tuple(self.__obj_cache.values()):
                wr.oid = None
            self.__obj_cache.clear()
        super()._do_close()

    def __clear_stale_obj_locked(self, ref):
        """Drop cache entry REF and release its references; lock must be held."""
        assert is_lock_owned_by_current_thread(self._lock)
        oid = ref.oid
        if oid:
            # This function can be invoked simultaneously by the weakref
            # callback and by __addref_cached_obj_locked, so make sure we
            # actually remove from __obj_cache only once.
            ref.oid = None
            del self.__obj_cache[oid]
            self.release_resource(oid, ref.refcount)

    @staticmethod
    def __on_cached_object_collected(ref):
        """Weakref callback: release the references held for a collected object.

        Does nothing if the owning connection has itself gone away.
        Any exception is treated as fatal.
        """
        # pylint: disable=protected-access
        try:
            connection_wr = ref.connection_wr
            if connection_wr:
                connection = connection_wr()
                if connection:
                    with connection._lock:
                        connection.__clear_stale_obj_locked(ref)
        except:
            die_due_to_fatal_exception("destroying orphaned object")

    def __addref_cached_obj_locked(self, oid):
        """Look up OID in the cache, absorbing one reference on a hit.

        Returns the live cached object, or None when OID is unknown or
        its object has already been collected.  In the latter case the
        stale entry is cleared.
        """
        ref = self.__obj_cache.get(oid)
        if ref is None:
            obj = None
        else:
            obj = ref()
            if obj is None:
                assert ref.connection_wr() is self
                self.__clear_stale_obj_locked(ref)
            else:
                assert ref.refcount >= 1
                ref.refcount += 1
        assert obj is None or (isinstance(obj, SharedObject) and obj.oid == oid)
        return obj

    def internal_cache_newborn(self, obj_type, obj_dict):
        """Called during unpickling of freshly-shipped resources.

        Build a new instance of OBJ_TYPE from OBJ_DICT, but without going
        through the object's constructor. Cache the new object locally by
        its OID and return the new object.
        """
        with self._lock:
            assert issubclass(obj_type, SharedObject)
            obj = object.__new__(obj_type)
            obj.__dict__.update(obj_dict)
            oid = obj.oid
            # Sometimes, we can be shipped a copy of an object we already
            # have. In this case, absorb the new reference, discard the
            # object we were just shipped, and return the canonical copy of
            # the object for this process.
            existing_obj = self.__addref_cached_obj_locked(obj.oid)
            if existing_obj:
                return existing_obj
            assert obj.oid not in self.__obj_cache
            callback = ClientConnectionBase.__on_cached_object_collected
            self.__obj_cache[oid] = ObjCacheWeakRef(obj, callback, self, oid)
            # pylint: disable=protected-access
            obj._resman_after_pull_hook()
            return obj

    def load_resource(self, oid):
        """Resolve an OID we receive over a socket connection.

        Called by the pickle machinery during deserialization.
        This function sinks a reference.
        Return the SharedObject corresponding to OID.
        """
        assert self is get_process_resmanc(), \
            "only the process resmanc can own resources"
        with self._lock:
            obj = self.__addref_cached_obj_locked(oid)
            if not obj:
                # We don't know about this object, so we need to get a copy.
                # adopt_resource pickles the object in such a way that
                # unpickling it while reading adopt_resource's return value
                # ends up calling cache_newborn, which inserts the object into
                # the object cache. We just check that it's done its job.
                obj = self.adopt_resource(oid)
                assert obj.oid == oid
                assert self.__obj_cache[oid].refcount >= 1
            return obj

    def _transact(self, api_nr, args, kwargs, oneway=False):
        """Perform API call number API_NR; implemented by subclasses.

        ARGS and KWARGS are the call's positional and keyword
        arguments.  ONEWAY is the flag passed by the generated API
        thunks, taken from the API function's `oneway` attribute.
        """
        # The signature accepts ONEWAY because every thunk passes it;
        # without it this stub would raise TypeError rather than the
        # intended NotImplementedError.
        raise NotImplementedError("abstract")

    def sync_reference_counts(self):
        """Flush any pending local reference count changes to the server"""
        # getpid requires a round trip, so it flushes by side effect
        self.getpid()
class DeathCallbackKey(object):
    """Token used to remember how to un-add death callback

    Holds the watched OID and the callable registered for it.
    """

    __slots__ = ["oid", "callback", "__weakref__"]

    def __init__(self, oid, callback):
        assert _assert_valid_oid(oid)
        assert callable(callback)
        self.oid, self.callback = oid, callback
class ClientConnection(ClientConnectionBase):
    """Socket connection to resource manager server.

    Transactions send a command over the client channel and, unless
    the call is one-way, wait for the reply while holding the
    connection lock.  Death notifications are handled by a dedicated
    thread that is started lazily by the first add_death_callback call.
    """

    def __init__(self, client_channel, id_, name):
        ClientConnectionBase.__init__(self, id_)
        self.__client_channel = the(Channel, client_channel)
        self.__name = name
        self.fd_numbers = client_channel.fd_numbers
        self.__death_callbacks = MultiDict()
        # Notification state is created on demand by
        # __ensure_notification_thread_locked.
        self.__notification_thread = None
        self.__notification_semaphore = None
        self.__notifications = None

    @property
    def name(self):
        """Name of this ClientConnection"""
        return self.__name

    def _do_close(self):
        # Take the lock so that nobody can observe __client_channel being
        # closed before we clear all our weak reference callbacks.
        with self._lock:
            if self.__notification_thread:
                self.__enqueue_notification((MSG_NOTIFICATION_QUIT_THREAD, None))
        # Join with lock unheld because notification thread occasionally
        # wants to take the lock.
        if self.__notification_thread:
            self.__notification_thread.join()
        with self._lock:
            if self.__death_callbacks:
                logc.warning("ignoring pending death callbacks")
            self.__death_callbacks.clear()
            self.__client_channel.close()
            self.__client_channel = None
            super()._do_close()

    def __repr__(self):
        return "ClientConnection(s={!r}, id={!r}, name={!r})".format(
            self.__client_channel,
            self.id,
            self.name)

    def _transact(self, api_nr, args, kwargs, oneway):
        """Send API call API_NR to the server and return its result.

        Returns None immediately after sending when ONEWAY is true.
        Otherwise waits for the reply: a success reply's payload is
        returned, an error reply's payload is re-raised.  Any failure
        while communicating is treated as fatal.
        """
        try:
            if __debug__:
                logc_tx.debug("starting transaction")
            with self._lock:
                msg = (MSG_COMMAND, api_nr, args, kwargs)
                self.__client_channel.send(msg, ToResmanSendContext)
                if __debug__:
                    logc_tx.debug("transaction sent request %r", msg)
                if oneway:
                    return None
                del msg, args, kwargs
                msg = self.__recv_reply_locked()
            if __debug__:
                logc_tx.debug("transaction recv msg=%r", msg)
            code, payload = msg
            assert code in (MSG_REPLY_SUCCESS, MSG_REPLY_ERROR)
        except:
            die_due_to_fatal_exception("talking to resource manager")
        if code == MSG_REPLY_ERROR:
            if __debug__:
                logc_tx.debug("transaction error %r", payload)
            reraise_exc_info(payload)
            assert False, "never reached"
        if __debug__:
            logc_tx.debug("transaction result %r", payload)
        return payload

    def __enqueue_notification(self, msg):
        """Queue notification MSG and wake the notification thread."""
        assert msg[0] >= MSG_NOTIFICATION_MIN, \
            "not a notification message: {!r}".format(msg)
        self.__notifications.append(msg)  # Atomic
        self.__notification_semaphore.release()

    def __recv_reply_locked(self, *, block=True):
        """Receive the next non-notification message from the channel.

        Notifications that arrive in the meantime are queued for the
        notification thread.  Returns None if the receive would block
        (possible only when BLOCK is false).  Failures are fatal.
        """
        try:
            while True:
                try:
                    msg = self.__client_channel.recv(ResourceRecvContext, block=block)
                except BlockingIOError:
                    msg = None
                else:
                    if msg[0] >= MSG_NOTIFICATION_MIN:
                        # Oops. We got a notification instead of the reply we
                        # wanted. Let the notification thread handle the
                        # notification. We shouldn't be getting notifications
                        # unless we asked for them, so the thread and its data
                        # structures should be up and running now.
                        self.__enqueue_notification(msg)
                        continue  # Try getting the reply we actually want
                return msg
        except:
            die_due_to_fatal_exception("talking to resource manager")

    def __recv_notifications_locked(self):
        """Drain and queue every message currently readable on the channel."""
        while True:
            try:
                # NOTE(review): unlike __recv_reply_locked, no receive
                # context is passed here -- confirm Channel.recv's default.
                msg = self.__client_channel.recv(block=False)
            except BlockingIOError:
                return
            self.__enqueue_notification(msg)

    def __extract_callbacks_for_dead_oids_locked(self, dead_oids):
        """Remove and return the callbacks registered for DEAD_OIDS."""
        # Keep lock outside the function so variable references can't leak
        # outside the lock.
        callbacks = []
        for dead_oid in dead_oids:
            for key in self.__death_callbacks.pop(dead_oid, ()):
                callbacks.append(key.callback)
        return callbacks

    def __dispatch_notifications(self, notifications):
        """Process a batch of NOTIFICATIONS on the notification thread.

        Runs the death callbacks for every OID reported gone, with no
        lock held, then acknowledges the batch to the server.  A quit
        notification ends processing at once, without running callbacks
        or sending the acknowledgement.
        """
        dead_oids = []
        for code, payload in notifications:
            if code != MSG_NOTIFICATION_OID_GONE:
                if code == MSG_NOTIFICATION_QUIT_THREAD:
                    assert self.closed
                    return
                logc.warning("unknown notification code %r", code)
                continue
            dead_oids.extend(payload)
        assert len(set(dead_oids)) == len(dead_oids)
        with self._lock:
            callbacks = self.__extract_callbacks_for_dead_oids_locked(dead_oids)
        # Run in registration order, dropping each callback as soon as
        # it has been called.
        callbacks.reverse()
        while callbacks:
            callbacks.pop()()
        self.__client_channel.send((MSG_ACK_NOTIFICATION, len(notifications)))

    def __run_notification_thread(self):
        """Run the actual notification listener thread"""
        try:
            sem = self.__notification_semaphore
            sem_fd = sem.fileno()
            notifications = self.__notifications
            poll = select.poll()
            for fd_number in self.__client_channel.fd_numbers:
                poll.register(fd_number, select.POLLIN)
            # Ask for readability only, like the channel descriptors.
            # The default event mask also includes POLLOUT, which would
            # report the descriptor ready whenever it is writable.
            poll.register(sem_fd, select.POLLIN)
            while True:
                ready_fd = [fd for fd, _ in poll.poll() if fd != sem_fd]
                if ready_fd:
                    with self._lock:
                        self.__recv_notifications_locked()
                nr_ready = sem.try_acquire_all()
                if nr_ready:
                    self.__dispatch_notifications(notifications[:nr_ready])
                    del notifications[:nr_ready]  # Atomic
                if self.closed:
                    break
        except:
            die_due_to_fatal_exception("running notification thread")

    def __ensure_notification_thread_locked(self):
        """Make sure that the notification-receiver thread is running"""
        if not self.__notification_thread:
            self.__notification_semaphore = WaitableSemaphore()
            self.__notifications = []
            # A connection may be unnamed (name is None); fall back to
            # a generic prefix so the thread name can still be built.
            self.__notification_thread = threading.Thread(
                name=(self.name or "resmanc") + "/notif",
                daemon=True,
                target=self.__run_notification_thread)
            self.__notification_thread.start()

    def add_death_callback(self, oid, callback):
        """Add a callback CALLBACK to be made when object OID dies

        CALLBACK is called with no arguments and with no locks held from a
        dedicated thread.
        If CALLBACK is given for a particular OID multiple times, it's
        called that many times.
        Return an opaque handle that can be passed to
        remove_death_callback.
        """
        key = DeathCallbackKey(oid, callback)
        with self._lock:
            self.__ensure_notification_thread_locked()
            # Link with the server only when MultiDict.add reports a
            # truthy result (presumably the first key for this OID).
            if self.__death_callbacks.add(oid, key):
                self.link_to_death_internal(oid)
        return key

    def remove_death_callback(self, key):
        """Remove a callback added with add_death_callback.

        Removing a callback that's not present is a no-op.
        """
        assert isinstance(key, DeathCallbackKey)
        with self._lock:
            try:
                last = self.__death_callbacks.remove(key.oid, key)
            except KeyError:
                return
            if last:
                self.unlink_from_death_internal(key.oid)
class ClientConnectionFactory(SafeClosingObject):
    """Factory for making a ClientConnection

    A factory is only good for creating a single connection. You can
    duplicate a factory to create many client objects for the same
    connection, which we use primarily for anycast.
    Unlike ClientConnection, ClientConnectionFactory can be sent between
    processes. Useful for bootstrapping a new child process.
    """

    def __init__(self, client_channel, id_, name=None):
        assert the(int, id_) > 0  # pylint: disable=compare-to-zero
        assert isinstance(client_channel, Channel)
        assert isinstance(name, (type(None), str))
        self.__client_channel = client_channel
        self.__id = id_
        self.__name = name

    @property
    def id(self):
        """ID of the connection we'll build"""
        return self.__id

    def __call__(self, subclass=ClientConnection, **kwargs):
        """Build the connection, as an instance of SUBCLASS.

        Extra keyword arguments are passed to SUBCLASS.  Once the
        connection has been built, the factory lets go of the channel.
        """
        assert issubclass(subclass, ClientConnection)
        channel = self.__client_channel
        connection = subclass(channel, self.__id, self.__name, **kwargs)
        self.__client_channel = None
        return connection

    def _do_close(self):
        # Close the channel only while the factory still holds it.
        channel = self.__client_channel
        if channel:
            channel.close()
        super()._do_close()

    def __repr__(self):
        template = "ClientConnectionFactory(s={!r},id={!r})"
        return template.format(self.__client_channel, self.__id)
class LocalClient(ClientConnectionBase):
    """Special resource manager client used in resource manager process.

    Instead of talking over a channel, transactions are passed
    directly to the server-side self connection, with arguments and
    results pushed through a pickle round trip on the way.
    """

    def __init__(self, server_connection):
        assert isinstance(server_connection, ServerConnection)
        assert server_connection.id == OID_SELF_CONNECTION
        ClientConnectionBase.__init__(self, server_connection.id)
        self.__server_connection = server_connection

    @property
    def server_connection(self):
        """The server connection corresponding to this client"""
        return self.__server_connection

    def __local_copy(self, obj):
        """Return OBJ after pickling and immediately unpickling it."""
        # We pickle and immediately unpickle so that we correctly
        # translate local and remote objects, cache references, and do
        # other work that happens during pickling and
        # unpickling. Hopefully, we see only small messages here and the
        # process isn't _that_ inefficient. If the resource manager is
        # sending 20MB numpy arrays to itself, something is
        # probably wrong.
        # TODO(dancol): check for simple cases where a round-trip through
        # pickle won't change anything and avoid the work. Also somehow
        # avoid the file descriptor dups.
        pickle_data, pc = fancy_pickle(obj, ToResmanSendContext)
        # The unpickling side is handed duplicates of the descriptors.
        pc.fdhs = [fdh.dup() for fdh in pc.fdhs]
        recv_context = ResmanClientRecvContext.from_(self.server_connection)
        return fancy_unpickle(pickle_data, pc.fdhs, recv_context)

    def _transact(self, api_nr, args, kwargs, _oneway):
        """Invoke API number API_NR directly on the server connection.

        Returns None when the server produces no reply; otherwise
        returns the payload of a success reply or re-raises the payload
        of an error reply.  Failures in the machinery itself are fatal.
        """
        assert get_process_resmanc() is self
        assert threading.current_thread() is tls.resman_server
        try:
            msg = self.__local_copy((MSG_COMMAND, api_nr, args, kwargs))
            del args, kwargs
            # pylint: disable=protected-access
            raw_reply = self.__server_connection._invoke_api(msg)
            reply = self.__local_copy(raw_reply)
            if not reply:
                return None
            del msg
        except:
            die_due_to_fatal_exception("internal resource manager operation")
        code, payload = reply
        assert code in (MSG_REPLY_SUCCESS, MSG_REPLY_ERROR)
        if code == MSG_REPLY_ERROR:
            try:
                reraise_exc_info(payload)
            finally:
                del code, payload  # Break reference cycle
        return payload

    def fork(self, fn, process_resmanc_factory):
        """Fork and call FN in the child.

        PROCESS_RESMANC_FACTORY makes the process resman client in the
        child. It's a separate parameter so callers can construct it in
        advance, wire up an apartment, and then fork.
        In the child, the brains are sucked out of the current resource
        manager server and control returned to toplevel before execution
        of the fork child function proceeds.
        Return the PID of the child.
        """
        server_connection = self.__server_connection
        return server_connection.fork(fn, process_resmanc_factory)
class SharedObjectMeta(type):
    """Metaclass that inhibits __init__ when an object is known

    A class body that defines __init__ has it replaced with a wrapper
    that runs the original only for instances whose `oid` is falsy.
    """

    # __init__ runs too late to change the __init__ slot
    def __new__(mcs, name, bases, dict_):
        declared_init = dict_.get("__init__")
        if not declared_init:
            # Nothing to guard: the class body defines no __init__.
            return super().__new__(mcs, name, bases, dict_)

        # pylint: disable=missing-docstring
        @wraps(declared_init)
        def new_init(self, *args, **kwargs):
            if self.oid:
                # Already a known object: skip initialization.
                return
            declared_init(self, *args, **kwargs)

        dict_["__init__"] = new_init
        return super().__new__(mcs, name, bases, dict_)
class SharedObject(object, metaclass=SharedObjectMeta):
    """Magic base class that makes a class shareable.

    Just inherit from this class and storing objects of any subtype
    should Just Work. Specifically, plain construction of subtypes will
    create a resource in the process connection.
    Subclass __init__ will be called only in the server process.
    Calling SharedObject.__init__ from subclass constructors is
    not necessary.
    Subclasses should not override __new__; object.__new__ will be used
    to construct instances of the subtype in various circumstances.
    """
    # N.B. This class is either trivial or highly magical, depending on
    # your perspective. It basically hijacks __new__ to turn object
    # instantiation into 1) transparent instantiation on the resource
    # manager server, followed by 2) a retrieval and de-pickling of the
    # resulting value.
    #
    # We need to use a metaclass to rename __init__: if we didn't, the
    # runtime would re-run it after our fake __new__, even client-side,
    # resulting in double initialization.
    #
    # See test_object_init_run_once.

    # Persistent cross-process ID for this object
    oid = None

    # Send object as copy, not reference
    #
    # Sending object contents make each send more expensive, but frees
    # the receiver from having to call adopt_resource if it doesn't have
    # the object already. Setting _resman_send_eager to true is useful
    # for small objects that don't own file descriptors.
    _resman_send_eager = False

    def __new__(cls, *args, **kwargs):
        if tls.unpickle_context:
            # Being rebuilt during unpickling: make a bare instance.
            return object.__new__(cls)
        # Ordinary construction goes through the process resmanc.
        return get_process_resmanc().make_resource(cls, args, kwargs)

    def __reduce__(self):
        reducer = tls.pickle_context
        assert isinstance(reducer, AbstractResourceReducer)
        return reducer.reduce_resource(self)

    def _resman_destroy(self):
        """Function called for resource cleanup in resman"""
        # Override me

    def _resman_explicit_close(self):
        """Function called in response to explicit destroy request"""
        # Override me

    def _resman_after_pull_hook(self):
        """Called just after a client obtains an object"""

    # It's kind of hacky to add process-management functions to the base
    # interface, but otherwise we'd have to teach shm.py about
    # process.py, creating a layering cycle.
    def _resman_get_process_status(self):
        """Return the exit status for an object"""
        raise NotImplementedError("not a process")
def cache_object_until_death(obj, oid):
    """Cache SharedObject OBJ until OID dies

    OID must name an object other than OBJ itself.  When the process
    resmanc is not a LocalClient, a strong reference to OBJ is held
    until the death callback for OID fires.
    """
    assert isinstance(obj, SharedObject)
    assert _assert_valid_oid(oid)
    assert obj.oid != oid
    resmanc = get_process_resmanc()
    if not isinstance(resmanc, LocalClient):
        # The callback's closure keeps this list, and so OBJ, alive
        # until the callback empties it.
        keepalive = [obj]

        def _cache_reclaim_callback():
            keepalive.clear()

        resmanc.add_death_callback(oid, _cache_reclaim_callback)