| """isort:skip_file""" |
| from pickle import ( # type: ignore[attr-defined] |
| _compat_pickle, |
| _extension_registry, |
| _getattribute, |
| _Pickler, |
| EXT1, |
| EXT2, |
| EXT4, |
| GLOBAL, |
| Pickler, |
| PicklingError, |
| STACK_GLOBAL, |
| ) |
| from struct import pack |
| from types import FunctionType |
| |
| from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer |
| |
| |
class PackagePickler(_Pickler):
    """Package-aware pickler.

    This behaves the same as a normal pickler, except it uses an `Importer`
    to find objects and modules to save.
    """

    def __init__(self, importer: Importer, *args, **kwargs):
        """Create a pickler that resolves globals through ``importer``.

        Args:
            importer: used by ``save_global`` to map objects to
                ``(module_name, name)`` pairs and to import those modules.
            *args, **kwargs: forwarded unchanged to ``pickle._Pickler``
                (e.g. the output file and ``protocol``).
        """
        self.importer = importer
        super().__init__(*args, **kwargs)

        # Make sure the dispatch table copied from _Pickler is up-to-date.
        # Previous issues have been encountered where a library (e.g. dill)
        # mutates _Pickler.dispatch: PackagePickler makes a copy when this lib
        # is imported, then the offending library removes its dispatch entries,
        # leaving PackagePickler with a stale dispatch table that may cause
        # unwanted behavior.
        self.dispatch = _Pickler.dispatch.copy()  # type: ignore[misc]
        # The dispatch table stores plain function objects, so overriding
        # save_global on this class is not enough on its own: the
        # FunctionType entry must be re-pointed at our version explicitly.
        self.dispatch[FunctionType] = PackagePickler.save_global  # type: ignore[assignment]

    def save_global(self, obj, name=None):
        """Write a by-name reference to ``obj`` into the pickle stream.

        This is a copy of ``pickle._Pickler.save_global``. The functional
        difference is the CHANGED section: the owning module is located and
        imported through ``self.importer`` instead of ``__import__``. One more
        deviation from the upstream copy: the protocol < 3 encoding error
        reports ``module_name`` rather than the module object.

        Args:
            obj: the object to reference by name.
            name: optional qualified name for ``obj``; passed through to
                ``self.importer.get_name``.

        Raises:
            PicklingError: if the importer raises ``ObjNotFoundError`` or
                ``ObjMismatchError`` for ``obj``, or if the identifier cannot
                be ASCII-encoded when pickling with protocol < 3.
        """
        # unfortunately the pickler code is factored in a way that
        # forces us to copy/paste this function. The importer-specific
        # change is marked CHANGED below.
        write = self.write  # type: ignore[attr-defined]

        # CHANGED: import module from module environment instead of __import__
        try:
            module_name, name = self.importer.get_name(obj, name)
        except (ObjNotFoundError, ObjMismatchError) as err:
            raise PicklingError(f"Can't pickle {obj}: {str(err)}") from None

        module = self.importer.import_module(module_name)
        _, parent = _getattribute(module, name)
        # END CHANGED

        if self.proto >= 2:  # type: ignore[attr-defined]
            # Globals registered in the extension registry (see
            # copyreg.add_extension) are written as a compact integer code,
            # using the smallest opcode that fits the code.
            code = _extension_registry.get((module_name, name))
            if code:
                assert code > 0
                if code <= 0xFF:
                    write(EXT1 + pack("<B", code))
                elif code <= 0xFFFF:
                    write(EXT2 + pack("<H", code))
                else:
                    write(EXT4 + pack("<i", code))
                return
        lastname = name.rpartition(".")[2]
        if parent is module:
            # obj is a direct attribute of the module: use its short name.
            name = lastname
        # Non-ASCII identifiers are supported only with protocols >= 3.
        if self.proto >= 4:  # type: ignore[attr-defined]
            # STACK_GLOBAL takes module and (possibly dotted) name from the
            # stack, so nested objects need no special handling here.
            self.save(module_name)  # type: ignore[attr-defined]
            self.save(name)  # type: ignore[attr-defined]
            write(STACK_GLOBAL)
        elif parent is not module:
            # Older protocols cannot name a nested object directly; emit
            # getattr(parent, lastname) instead.
            self.save_reduce(getattr, (parent, lastname))  # type: ignore[attr-defined]
        elif self.proto >= 3:  # type: ignore[attr-defined]
            write(
                GLOBAL
                + bytes(module_name, "utf-8")
                + b"\n"
                + bytes(name, "utf-8")
                + b"\n"
            )
        else:
            if self.fix_imports:  # type: ignore[attr-defined]
                # Translate Python 3 module/attribute names back to their
                # Python 2 spellings so a Python 2 unpickler can load them.
                r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
                r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
                if (module_name, name) in r_name_mapping:
                    module_name, name = r_name_mapping[(module_name, name)]
                elif module_name in r_import_mapping:
                    module_name = r_import_mapping[module_name]
            try:
                write(
                    GLOBAL
                    + bytes(module_name, "ascii")
                    + b"\n"
                    + bytes(name, "ascii")
                    + b"\n"
                )
            except UnicodeEncodeError:
                # Report the identifier we tried to write (module_name is a
                # str), not the repr of the imported module object.
                raise PicklingError(
                    "can't pickle global identifier '%s.%s' using "
                    "pickle protocol %i" % (module_name, name, self.proto)  # type: ignore[attr-defined]
                ) from None

        self.memoize(obj)  # type: ignore[attr-defined]
| |
| |
def create_pickler(data_buf, importer, protocol=4):
    """Return a pickler that writes to ``data_buf`` using ``protocol``.

    A custom ``importer`` gets a ``PackagePickler`` so that globals are
    resolved through it. The default ``sys_importer`` gets the stock
    ``Pickler`` instead.
    """
    if importer is not sys_importer:
        return PackagePickler(importer, data_buf, protocol=protocol)
    # With the normal import system no custom lookup is needed, so the
    # (faster) C implementation of pickle can be used.
    return Pickler(data_buf, protocol=protocol)