| """isort:skip_file""" |
| from pickle import EXT1, EXT2, EXT4, GLOBAL, STACK_GLOBAL, Pickler, PicklingError |
| from pickle import _compat_pickle, _extension_registry, _getattribute, _Pickler # type: ignore |
| from struct import pack |
| from types import FunctionType |
| |
| from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer |
| |
| |
class PackagePickler(_Pickler):
    """Package-aware pickler.

    This behaves the same as a normal pickler, except it uses an `Importer`
    to find objects and modules to save.
    """

    # Work on a copy so registering handlers below does not mutate the
    # dispatch table shared with ``pickle._Pickler``.
    dispatch = _Pickler.dispatch.copy()

    def __init__(self, importer: Importer, *args, **kwargs):
        """Create a pickler that resolves globals through ``importer``.

        Args:
            importer: object used by ``save_global`` to look up the
                ``(module_name, name)`` of an object and to import modules.
            *args: forwarded unchanged to ``pickle._Pickler.__init__``.
            **kwargs: forwarded unchanged to ``pickle._Pickler.__init__``.
        """
        self.importer = importer
        super().__init__(*args, **kwargs)

    def save_global(self, obj, name=None):
        """Pickle ``obj`` by reference (module name + qualified name).

        Args:
            obj: the object to be saved as a global reference.
            name: optional name passed through to ``importer.get_name``.

        Raises:
            PicklingError: if the importer reports that ``obj`` cannot be
                found or does not match, or if the identifier cannot be
                encoded as ASCII for protocols below 3.
        """
        # unfortunately the pickler code is factored in a way that
        # forces us to copy/paste this function. The only change is marked
        # CHANGED below.
        write = self.write

        # CHANGED: import module from module environment instead of __import__
        try:
            module_name, name = self.importer.get_name(obj, name)
        except (ObjNotFoundError, ObjMismatchError) as err:
            raise PicklingError(f"Can't pickle {obj}: {str(err)}") from None

        module = self.importer.import_module(module_name)
        _, parent = _getattribute(module, name)
        # END CHANGED

        if self.proto >= 2:
            # Objects registered in the extension registry are written as a
            # compact integer code sized to the smallest opcode that fits.
            code = _extension_registry.get((module_name, name))
            if code:
                assert code > 0
                if code <= 0xFF:
                    write(EXT1 + pack("<B", code))
                elif code <= 0xFFFF:
                    write(EXT2 + pack("<H", code))
                else:
                    write(EXT4 + pack("<i", code))
                return
        lastname = name.rpartition(".")[2]
        if parent is module:
            name = lastname
        # Non-ASCII identifiers are supported only with protocols >= 3.
        if self.proto >= 4:
            self.save(module_name)
            self.save(name)
            write(STACK_GLOBAL)
        elif parent is not module:
            # Nested attribute: express it as getattr(parent, lastname).
            self.save_reduce(getattr, (parent, lastname))
        elif self.proto >= 3:
            write(
                GLOBAL
                + bytes(module_name, "utf-8")
                + b"\n"
                + bytes(name, "utf-8")
                + b"\n"
            )
        else:
            if self.fix_imports:
                r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
                r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
                if (module_name, name) in r_name_mapping:
                    module_name, name = r_name_mapping[(module_name, name)]
                elif module_name in r_import_mapping:
                    module_name = r_import_mapping[module_name]
            try:
                write(
                    GLOBAL
                    + bytes(module_name, "ascii")
                    + b"\n"
                    + bytes(name, "ascii")
                    + b"\n"
                )
            except UnicodeEncodeError:
                # Report the dotted identifier, not the module object's repr.
                raise PicklingError(
                    "can't pickle global identifier '%s.%s' using "
                    "pickle protocol %i" % (module_name, name, self.proto)
                ) from None

        self.memoize(obj)

    # Route plain functions through the importer-aware save_global above.
    dispatch[FunctionType] = save_global
| |
| |
def create_pickler(data_buf, importer):
    """Return a protocol-3 pickler that writes to ``data_buf``.

    Args:
        data_buf: destination passed to the pickler as its output file.
        importer: the importer in use; anything other than ``sys_importer``
            yields a ``PackagePickler`` bound to that importer.

    Returns:
        A ``PackagePickler`` for a custom importer, otherwise the standard
        ``pickle.Pickler``.
    """
    if importer is not sys_importer:
        return PackagePickler(importer, data_buf, protocol=3)
    # With the normal import system no custom lookup is needed, so the
    # faster C implementation of pickle can be used.
    return Pickler(data_buf, protocol=3)