blob: 7d629b4e2380dccf649f78b6a633742937384abf [file] [log] [blame]
"""isort:skip_file"""
from pickle import EXT1, EXT2, EXT4, GLOBAL, STACK_GLOBAL, Pickler, PicklingError
from pickle import _compat_pickle, _extension_registry, _getattribute, _Pickler # type: ignore
from struct import pack
from types import FunctionType
from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer
class PackagePickler(_Pickler):
    """Package-aware pickler.

    This behaves the same as a normal pickler, except it uses an `Importer`
    to find objects and modules to save.
    """

    # Private copy of the pure-Python pickler's type -> save-function table,
    # so that registering the override at the bottom of this class does not
    # modify ``_Pickler.dispatch`` itself.
    dispatch = _Pickler.dispatch.copy()

    def __init__(self, importer: Importer, *args, **kwargs):
        """Create a pickler that resolves globals through ``importer``.

        Args:
            importer: the `Importer` that ``save_global`` uses to look up an
                object's module/name and to import that module.
            *args: forwarded unchanged to ``pickle._Pickler.__init__``
                (e.g. the output file).
            **kwargs: forwarded unchanged to ``pickle._Pickler.__init__``
                (e.g. ``protocol``).
        """
        self.importer = importer
        super().__init__(*args, **kwargs)

    def save_global(self, obj, name=None):
        """Write a by-reference record for the global ``obj``.

        Unfortunately the pickler code is factored in a way that forces us to
        copy/paste ``pickle._Pickler.save_global``. The functional change is
        marked CHANGED below: the module and name are resolved through
        ``self.importer`` instead of ``__import__``. The ASCII-encoding error
        message also reports the module name rather than the module object.

        Args:
            obj: the object to save by reference.
            name: optional name hint, forwarded to ``self.importer.get_name``.

        Raises:
            PicklingError: if the importer cannot find ``obj`` (or finds a
                different object under that name), or if the identifier
                cannot be ASCII-encoded under a protocol below 3.
        """
        write = self.write

        # CHANGED: import module from module environment instead of __import__
        try:
            module_name, name = self.importer.get_name(obj, name)
        except (ObjNotFoundError, ObjMismatchError) as err:
            raise PicklingError(f"Can't pickle {obj}: {err}") from None
        module = self.importer.import_module(module_name)
        # ``parent`` is the object holding the last component of the
        # (possibly dotted) name; for a top-level name it is ``module``.
        _, parent = _getattribute(module, name)
        # END CHANGED

        if self.proto >= 2:
            # Globals present in pickle's extension registry are written as a
            # compact 1-, 2- or 4-byte integer code instead of by name.
            code = _extension_registry.get((module_name, name))
            if code:
                assert code > 0
                if code <= 0xFF:
                    write(EXT1 + pack("<B", code))
                elif code <= 0xFFFF:
                    write(EXT2 + pack("<H", code))
                else:
                    write(EXT4 + pack("<i", code))
                return

        # Last component of a dotted (qualified) name, e.g. "C.m" -> "m".
        lastname = name.rpartition(".")[2]
        if parent is module:
            name = lastname
        # Non-ASCII identifiers are supported only with protocols >= 3.
        if self.proto >= 4:
            # Module and (possibly dotted) name are pushed as strings and
            # combined by STACK_GLOBAL.
            self.save(module_name)
            self.save(name)
            write(STACK_GLOBAL)
        elif parent is not module:
            # Below protocol 4 a nested attribute is rebuilt on load as
            # getattr(parent, lastname).
            self.save_reduce(getattr, (parent, lastname))
        elif self.proto >= 3:
            write(
                GLOBAL
                + bytes(module_name, "utf-8")
                + b"\n"
                + bytes(name, "utf-8")
                + b"\n"
            )
        else:
            if self.fix_imports:
                # Map Python 3 module/name pairs back to their Python 2
                # equivalents for old-protocol pickles.
                r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
                r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
                if (module_name, name) in r_name_mapping:
                    module_name, name = r_name_mapping[(module_name, name)]
                elif module_name in r_import_mapping:
                    module_name = r_import_mapping[module_name]
            try:
                write(
                    GLOBAL
                    + bytes(module_name, "ascii")
                    + b"\n"
                    + bytes(name, "ascii")
                    + b"\n"
                )
            except UnicodeEncodeError:
                # Report the module *name*: ``module`` is the module object
                # returned by the importer, so its repr is not an identifier.
                raise PicklingError(
                    "can't pickle global identifier '%s.%s' using "
                    "pickle protocol %i" % (module_name, name, self.proto)
                ) from None

        self.memoize(obj)

    # The dispatch table stores plain functions, so the copied entry still
    # refers to ``_Pickler.save_global``; rebind it to this override.
    dispatch[FunctionType] = save_global
def create_pickler(data_buf, importer):
    """Return a protocol-3 pickler that writes into ``data_buf``.

    Args:
        data_buf: the file-like object the pickler writes to.
        importer: the `Importer` used to resolve globals while pickling.

    Returns:
        ``pickle.Pickler`` when ``importer`` is ``sys_importer``; otherwise a
        ``PackagePickler`` bound to ``importer``.
    """
    if importer is not sys_importer:
        # A custom importer requires the package-aware pickler so that
        # globals are looked up through that importer.
        return PackagePickler(importer, data_buf, protocol=3)
    # With the normal import library system we can use the C implementation
    # of pickle, which is faster.
    return Pickler(data_buf, protocol=3)