blob: e7c3ce76c8dc30b58d84c47790a92033dad92d61 [file] [log] [blame]
# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support code for pickling and unpickling"""
import io
from socket import socket
import logging
from copyreg import dispatch_table as _copyreg_dispatch_table
from pickle import Pickler, Unpickler, HIGHEST_PROTOCOL, dumps
from .util import tls, FdHolder, assert_seq_type
log = logging.getLogger(__name__)
class ReductionError(Exception):
    """Raised when an object cannot be reduced (pickled) for IPC."""
class MessageTooLargeError(ReductionError):
    """Raised when a pickled message exceeds the allowed size limit."""
# Maps a type object to its reducer function. Populated by register()
# and merged into each Pickler's dispatch_table by fancy_pickle().
_dispatch_table = {} # pylint: disable=invalid-name
def __newobj_ex__(cls, *args, **kwargs):
"""Never called: activates special case in pickle"""
return cls.__new__(cls, *args, **kwargs)
def __newobj__(cls, *args):
"""Never called: activates special case in pickle"""
return cls.__new__(cls, *args)
def register(type_, reducer):
    """Install REDUCER as the modernmp IPC reducer for TYPE_.

    TYPE_ must be a type object; REDUCER is a callable with the same
    contract as __reduce__. The registration is recorded in the
    module-level dispatch table consulted by fancy_pickle().
    """
    assert isinstance(type_, type)
    _dispatch_table[type_] = reducer
def payload_is_in_fd_marker():
    """Sentinel indicating the real payload is in a shared memory segment"""
    # A function makes a good marker object: functions pickle by
    # qualified name, so the unpickled marker is the very same
    # singleton and can be compared with `is`.
    assert False, "this is a marker function and should never be called"
class PickleContext(object):
    """Per-thread bookkeeping for an in-progress pickle operation."""

    # Optional zero-arity factory producing an FdHolder for a fresh
    # shared memory segment; None means jumbo payloads are unsupported.
    make_shared_memory_fd = None

    def __init__(self):
        # FdHolders queued for transmission with the current message.
        self.fdhs = []
        # Numeric FDs already saved, for duplicate detection.
        self.__seen_fd_numbers = set()

    def save_fd(self, fdh):
        """Remember to send a file descriptor with the current message

        FDH must be an FdHolder. Return an opaque key that must be
        pickled and then given to UnpickleContext.get_fd().
        """
        assert isinstance(fdh, FdHolder)
        # Borrowed FdHolder instances can, in rare cases, present the
        # same numeric FD more than once; dup so each saved entry has a
        # distinct descriptor and persistence code stays unconfused.
        if fdh.fileno() in self.__seen_fd_numbers:
            fdh = fdh.dup()
            assert fdh.fileno() not in self.__seen_fd_numbers
        key = len(self.fdhs)
        if __debug__:
            # In debug builds, tag the key with the holder's identity so
            # the receiving side can cross-check it.
            key = (key, fdh.fdid) # pylint: disable=redefined-variable-type
        self.fdhs.append(fdh)
        self.__seen_fd_numbers.add(fdh.fileno())
        return key
class UnpickleContext(object):
    """Per-thread bookkeeping for an in-progress unpickle operation."""

    def __init__(self, *, fdhs):
        # FdHolders received with the message; slots are nulled out as
        # each one is claimed via get_fd().
        if fdhs:
            assert assert_seq_type(list, FdHolder, fdhs)
        self.fdhs = fdhs

    def get_fd(self, key):
        """Claim the file descriptor saved under KEY.

        KEY is the return value from PickleContext.save_fd(). Return an
        FdHolder; ownership transfers to the caller, who may close,
        release, or store it as needed.
        """
        if __debug__:
            # Debug-build keys carry the holder identity; verify it.
            index, saved_fdid = key
            assert self.fdhs[index].fdid == saved_fdid
        else:
            index = key
        # Hand over ownership exactly once by nulling out our slot.
        claimed, self.fdhs[index] = self.fdhs[index], None
        return claimed
def fancy_pickle(obj,
                 make_pickle_context=PickleContext,
                 size_limit=None):
    """Pickle object OBJ using our internal pickle registration.

    SIZE_LIMIT is the maximum size of the pickled data to bundle.
    If the payload is larger than SIZE_LIMIT, then either fail by
    raising MessageTooLargeError (if the pickle context's
    make_shared_memory_fd is None) or use the pickle context's
    make_shared_memory_fd to make a file into which we place the large
    pickle payload and which we pickle in lieu of the raw pickle file.

    Return a tuple (PICKLE_DATA, PICKLE_CONTEXT), where PICKLE_DATA is a
    byte buffer of some sort containing the actual pickle data, and
    PICKLE_CONTEXT is the context object we created using the zero-arity
    MAKE_PICKLE_CONTEXT function. Notably, PICKLE_CONTEXT.fdhs is a
    sequence of file descriptor holders to send along with the
    pickle data.

    Raises ReductionError if OBJ is the marker itself and shared memory
    is unavailable.
    """
    # TODO(dancol): provide a way to reuse buffers?
    out_file = io.BytesIO()
    pickler = Pickler(out_file, HIGHEST_PROTOCOL)
    # Need to copy the global dispatch table in case it's changed.
    # We override it with our registered picklers in any case.
    pickler.dispatch_table = _copyreg_dispatch_table.copy()
    pickler.dispatch_table.update(_dispatch_table)
    with tls.set_pickle_context(make_pickle_context()) as pc:
        assert isinstance(pc, PickleContext)
        pickler.dump(obj)
        pickle_data = out_file.getbuffer()
        too_big = size_limit and len(pickle_data) > size_limit
        # Pickling the marker itself must go through the FD path too, so
        # that unpickling can distinguish it from a real jumbo payload.
        if obj is payload_is_in_fd_marker or too_big:
            if not pc.make_shared_memory_fd:
                if too_big:
                    raise MessageTooLargeError(
                        "pickled data is {!r} bytes long; max is {!r}".format(
                            len(pickle_data), size_limit))
                # BUGFIX: error message previously read "memoy".
                raise ReductionError(
                    "cannot pickle payload without memory file support")
            jumbo_fdh = pc.make_shared_memory_fd()
            assert isinstance(jumbo_fdh, FdHolder)
            # as_file() steals the FD into the file object; re-borrow so
            # we still have a holder to ship in pc.fdhs while the file
            # object keeps the descriptor alive.
            jumbo_file = jumbo_fdh.as_file("r+b", steal=True)
            jumbo_fdh = FdHolder.borrow(jumbo_file)
            jumbo_file.write(pickle_data)
            jumbo_file.seek(0)
            # Ship only the marker inline; the receiver reads the real
            # payload back out of the transmitted FD.
            pickle_data = dumps(payload_is_in_fd_marker, HIGHEST_PROTOCOL)
            if not pc.fdhs:
                pc.fdhs = []
            pc.fdhs.append(jumbo_fdh)
    return pickle_data, pc
def fancy_unpickle(pickle_data,
                   fdhs,
                   make_unpickle_context=UnpickleContext):
    """Inverse of fancy_pickle

    PICKLE_DATA is the bundle that fancy_pickle produced; FDHS is a list
    of file descriptors pointing at the same open file descriptions that
    fancy_pickle produced. This list is destructively modified so we
    can close FDs early when possible.

    MAKE_UNPICKLE_CONTEXT is a function of one argument (which is just
    FDHS) that builds an UnpickleContext for the unpickle operation.

    Return the unpickled object.
    """
    with tls.set_unpickle_context(make_unpickle_context(fdhs=fdhs)) as uc:
        assert isinstance(uc, UnpickleContext)
        in_file = io.BytesIO(pickle_data)
        unpickler = Unpickler(in_file)
        obj = unpickler.load()
        if obj is payload_is_in_fd_marker:
            # The marker means the real payload lives in the last
            # transmitted FD: reopen it as a file and unpickle again.
            in_file = fdhs.pop().as_file("rb", steal=True)
            # Rebind the thread-local context directly so reducers in
            # the jumbo payload see a fresh context over the remaining
            # FDs. NOTE(review): this relies on set_unpickle_context
            # restoring the previous context on exit — confirm in
            # util.tls.
            tls.unpickle_context = uc = make_unpickle_context(fdhs=fdhs)
            unpickler = Unpickler(in_file)
            obj = unpickler.load()
        # Every transmitted FD should have been consumed by a reducer;
        # leftovers would leak descriptors.
        assert not any(uc.fdhs), "all FDs should be claimed"
    return obj
def _rebuild_socket(state):
    """Reconstruct a socket from the tuple produced by _reduce_socket."""
    family, kind, proto, cookie = state
    holder = tls.unpickle_context.get_fd(cookie)
    sock = socket(family, kind, proto, holder.fileno())
    # The socket object now wraps the FD; detach so the holder does not
    # close it out from under the socket.
    holder.detach()
    return sock
def _reduce_socket(sock):
    """__reduce__-style reducer: ship a socket's FD out of band."""
    cookie = tls.pickle_context.save_fd(FdHolder.borrow(sock))
    state = (sock.family, sock.type, sock.proto, cookie)
    return _rebuild_socket, (state,)
register(socket, _reduce_socket)
def _rebuild_fdholder(state):
    """Reclaim the FdHolder saved under cookie STATE by _reduce_fdholder."""
    holder = tls.unpickle_context.get_fd(state)
    return holder
def _reduce_fdholder(fd):
    """__reduce__-style reducer: ship an FdHolder's FD out of band."""
    cookie = tls.pickle_context.save_fd(fd)
    return _rebuild_fdholder, (cookie,)
register(FdHolder, _reduce_fdholder)