# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """Support code for pickling and unpickling""" |
| import io |
| from socket import socket |
| import logging |
| |
| from copyreg import dispatch_table as _copyreg_dispatch_table |
| from pickle import Pickler, Unpickler, HIGHEST_PROTOCOL, dumps |
| |
| from .util import tls, FdHolder, assert_seq_type |
| |
| log = logging.getLogger(__name__) |
| |
class ReductionError(Exception):
  """Error raised when pickling an object for IPC fails"""

class MessageTooLargeError(ReductionError):
  """Error indicating that a message was too large to send"""

_dispatch_table = {}  # pylint: disable=invalid-name

def __newobj_ex__(cls, *args, **kwargs):
  """Never called: activates special case in pickle"""
  return cls.__new__(cls, *args, **kwargs)

def __newobj__(cls, *args):
  """Never called: activates special case in pickle"""
  return cls.__new__(cls, *args)

def register(type_, reducer):
  """Register a reducer for modernmp IPC

  TYPE_ is a type object; REDUCER is a function like __reduce__.
  """
  assert isinstance(type_, type)
  _dispatch_table[type_] = reducer

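# Illustrative sketch (not part of the API): registering a reducer for a
# hypothetical Token type so fancy_pickle knows how to serialize it.
#
#   def _rebuild_token(value):
#     return Token(value)
#
#   def _reduce_token(token):
#     return _rebuild_token, (token.value,)
#
#   register(Token, _reduce_token)
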
def payload_is_in_fd_marker():
  """Indicates that the real payload is in a shared memory segment"""
  # We use a function as our marker object because functions are
  # unpickled by name, preserving the singleton-ness of the marker.
  assert False, "this is a marker function and should never be called"

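# A sketch of the property the comment above relies on, using the standard
# pickle module: a module-level function unpickles back to the same object.
#
#   import pickle
#   blob = pickle.dumps(payload_is_in_fd_marker)
#   assert pickle.loads(blob) is payload_is_in_fd_marker
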
class PickleContext(object):
  """Thread-local context for pickling"""
  def __init__(self):
    self.fdhs = []
    self.__seen_fd_numbers = set()

  def save_fd(self, fdh):
    """Remember to send a file descriptor with the current message

    FDH must be an FdHolder. Return an opaque value that must be
    pickled and then given to UnpickleContext.get_fd().
    """
    assert isinstance(fdh, FdHolder)
    # In certain rare cases involving borrowed FdHolder instances, we
    # might try to save the same numeric FD more than once. To avoid
    # confusing persistence code, just dup these.
    if fdh.fileno() in self.__seen_fd_numbers:
      fdh = fdh.dup()
      assert fdh.fileno() not in self.__seen_fd_numbers
    key = len(self.fdhs)
    if __debug__:
      key = (key, fdh.fdid)  # pylint: disable=redefined-variable-type
    self.fdhs.append(fdh)
    self.__seen_fd_numbers.add(fdh.fileno())
    return key

  make_shared_memory_fd = None
  """Zero-argument function returning an FdHolder for a fresh shared memory
  segment, or None if shared-memory ("jumbo") payloads are unsupported"""

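# Illustrative sketch of enabling shared-memory payloads by supplying a
# make_shared_memory_fd on a context subclass. How the memfd is wrapped in an
# FdHolder depends on FdHolder's constructor, which is not shown in this file,
# so FdHolder.from_raw_fd below is hypothetical, as are BIG_OBJ and the name
# passed to memfd_create.
#
#   import os
#
#   class ShmPickleContext(PickleContext):
#     @staticmethod
#     def make_shared_memory_fd():
#       fd = os.memfd_create("modernmp-jumbo")  # Linux-only
#       return FdHolder.from_raw_fd(fd)  # hypothetical adoption API
#
#   data, pc = fancy_pickle(BIG_OBJ, ShmPickleContext, size_limit=64 * 1024)
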
class UnpickleContext(object):
  """Thread-local context for unpickling"""
  def __init__(self, *, fdhs):
    if fdhs:
      assert assert_seq_type(list, FdHolder, fdhs)
    self.fdhs = fdhs

  def get_fd(self, key):
    """Retrieve a file descriptor sent with the current message.

    KEY is the return value from PickleContext.save_fd(). Return an
    FdHolder. The caller owns the FdHolder and may close, release, or
    store it as needed.
    """
    if __debug__:
      index, saved_fdid = key
      assert self.fdhs[index].fdid == saved_fdid
    else:
      index = key
    fdh = self.fdhs[index]
    self.fdhs[index] = None
    return fdh

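# A reducer that needs to ship a file descriptor pairs the two contexts: it
# calls tls.pickle_context.save_fd() while reducing and
# tls.unpickle_context.get_fd() while rebuilding; the socket and FdHolder
# reducers at the bottom of this file follow this pattern. A sketch for a
# hypothetical Thing type that owns a descriptor:
#
#   def _rebuild_thing(state):
#     some_state, cookie = state
#     return Thing(some_state, tls.unpickle_context.get_fd(cookie))
#
#   def _reduce_thing(thing):
#     cookie = tls.pickle_context.save_fd(FdHolder.borrow(thing))
#     return _rebuild_thing, ((thing.some_state, cookie),)
#
#   register(Thing, _reduce_thing)
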
def fancy_pickle(obj,
                 make_pickle_context=PickleContext,
                 size_limit=None):
  """Pickle object OBJ using our internal pickle registration.

  SIZE_LIMIT is the maximum size of pickled data to bundle inline.
  If the payload is larger than SIZE_LIMIT, then either fail by
  raising MessageTooLargeError (if the pickle context's
  make_shared_memory_fd is None) or use the pickle context's
  make_shared_memory_fd to make a file into which we place the large
  pickle payload; we then pickle a small marker in lieu of the raw
  payload and send the file's descriptor along with the message.

  Return a tuple (PICKLE_DATA, PICKLE_CONTEXT), where PICKLE_DATA is a
  byte buffer of some sort containing the actual pickle data, and
  PICKLE_CONTEXT is the context object we created using the zero-arity
  MAKE_PICKLE_CONTEXT function. Notably, PICKLE_CONTEXT.fdhs is a
  sequence of file descriptor holders to send along with the
  pickle data.
  """
  # TODO(dancol): provide a way to reuse buffers?
  out_file = io.BytesIO()
  # Need to copy the global dispatch table in case it's changed.
  # We override it with our registered picklers in any case.
  pickler = Pickler(out_file, HIGHEST_PROTOCOL)
  pickler.dispatch_table = _copyreg_dispatch_table.copy()
  pickler.dispatch_table.update(_dispatch_table)
  with tls.set_pickle_context(make_pickle_context()) as pc:
    assert isinstance(pc, PickleContext)
    pickler.dump(obj)
    pickle_data = out_file.getbuffer()
    too_big = size_limit and len(pickle_data) > size_limit
    if obj is payload_is_in_fd_marker or too_big:
      if not pc.make_shared_memory_fd:
        if too_big:
          raise MessageTooLargeError(
              "pickled data is {!r} bytes long; max is {!r}".format(
                  len(pickle_data), size_limit))
        else:
          raise ReductionError(
              "cannot pickle payload without memory file support")
      jumbo_fdh = pc.make_shared_memory_fd()
      assert isinstance(jumbo_fdh, FdHolder)
      jumbo_file = jumbo_fdh.as_file("r+b", steal=True)
      jumbo_fdh = FdHolder.borrow(jumbo_file)
      jumbo_file.write(pickle_data)
      jumbo_file.seek(0)
      pickle_data = dumps(payload_is_in_fd_marker, HIGHEST_PROTOCOL)
      if not pc.fdhs:
        pc.fdhs = []
      pc.fdhs.append(jumbo_fdh)
  return pickle_data, pc

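# Illustrative sketch (not executed at import): with the default context,
# which has no shared-memory support, an over-limit payload is rejected
# rather than spilled to a memfd.
#
#   try:
#     fancy_pickle(b"x" * 1024, size_limit=16)
#   except MessageTooLargeError:
#     pass  # expected: payload exceeds SIZE_LIMIT and there is no fallback
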
def fancy_unpickle(pickle_data,
                   fdhs,
                   make_unpickle_context=UnpickleContext):
  """Inverse of fancy_pickle

  PICKLE_DATA is the bundle that fancy_pickle produced; FDHS is a list
  of FdHolder objects referring to the same open file descriptions
  that fancy_pickle produced. This list is destructively modified so
  we can close FDs early when possible.

  MAKE_UNPICKLE_CONTEXT is a function taking a single FDHS keyword
  argument that builds an UnpickleContext for the unpickle operation.

  Return the unpickled object.
  """
  with tls.set_unpickle_context(make_unpickle_context(fdhs=fdhs)) as uc:
    assert isinstance(uc, UnpickleContext)
    in_file = io.BytesIO(pickle_data)
    unpickler = Unpickler(in_file)
    obj = unpickler.load()
    if obj is payload_is_in_fd_marker:
      # The real payload was spilled into the last FD: unpickle it from
      # there using a context that holds the remaining FDs.
      in_file = fdhs.pop().as_file("rb", steal=True)
      tls.unpickle_context = uc = make_unpickle_context(fdhs=fdhs)
      unpickler = Unpickler(in_file)
      obj = unpickler.load()
    assert not any(uc.fdhs), "all FDs should be claimed"
  return obj

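# Illustrative round trip (a sketch; in real use the two halves run in
# different processes and PC.fdhs travels over the IPC channel, e.g. via
# SCM_RIGHTS):
#
#   data, pc = fancy_pickle({"answer": 42})
#   obj = fancy_unpickle(data, pc.fdhs)
#   assert obj == {"answer": 42}
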
def _rebuild_socket(state):
  family, type_, proto, fd_cookie = state
  fd_holder = tls.unpickle_context.get_fd(fd_cookie)
  # socket() called with an explicit fileno wraps that descriptor rather
  # than duplicating it, so detach the holder to hand ownership of the FD
  # to the new socket object.
  sock = socket(family, type_, proto, fd_holder.fileno())
  fd_holder.detach()
  return sock

def _reduce_socket(sock):
  state = (sock.family, sock.type, sock.proto,
           tls.pickle_context.save_fd(FdHolder.borrow(sock)))
  return _rebuild_socket, (state,)

register(socket, _reduce_socket)

def _rebuild_fdholder(state):
  return tls.unpickle_context.get_fd(state)

def _reduce_fdholder(fd):
  state = tls.pickle_context.save_fd(fd)
  return _rebuild_fdholder, (state,)

register(FdHolder, _reduce_fdholder)
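
# Illustrative sketch: pickling a socket does not serialize its FD inline.
# The reducer above records the FD on the pickle context, and the caller is
# responsible for sending pc.fdhs out of band (SOME_SOCKET is hypothetical):
#
#   data, pc = fancy_pickle(SOME_SOCKET)
#   assert len(pc.fdhs) == 1  # the socket's descriptor travels separately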