| import difflib | 
 | import os | 
 | import io | 
 | import shutil | 
 | import struct | 
 | import sys | 
 | import torch | 
 | import tarfile | 
 | import tempfile | 
 | import warnings | 
 | from contextlib import closing, contextmanager | 
 | from enum import Enum | 
 | from ._utils import _import_dotted_name | 
 | from torch._sources import get_source_lines_and_file | 
 | from torch.types import Storage | 
 | from torch.storage import _get_dtype_from_pickle_storage_type | 
 | from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Type, Tuple, Union, IO | 
 | from typing_extensions import TypeAlias  # Python 3.10+ | 
 | import copyreg | 
 | import pickle | 
 | import pathlib | 
 | import torch._weights_only_unpickler as _weights_only_unpickler | 
 |  | 
# Default pickle protocol used by `save` (the default of its `pickle_protocol`).
DEFAULT_PROTOCOL = 2

# Type sizes recorded in the legacy format's `sys_info` header (see
# `_legacy_save`). The '=' prefix makes `struct` use its *standard* sizes, so
# these are fixed at 4 (long), 4 (int) and 2 (short) bytes on every platform
# rather than the platform's native C sizes.
LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size

# Header values pickled at the start of legacy (non-zipfile) checkpoints by
# `_legacy_save`: the magic number first, then the protocol version (which is
# also repeated inside `sys_info`).
MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
PROTOCOL_VERSION = 1001
# Not referenced in the visible part of this module; presumably used to join
# storage keys elsewhere in the file -- verify before relying on it.
STORAGE_KEY_SEPARATOR = ','

# Accepted by `save` as its `f` argument: a path (str / os.PathLike) or a
# binary file-like object.
FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
# Presumably the accepted forms of `load`'s `map_location` argument (defined
# outside this view) -- confirm.
MAP_LOCATION: TypeAlias = Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]]
# Storage kinds handled by the taggers / deserializers given to `register_package`.
STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]

# Public API of this module.
__all__ = [
    'SourceChangeWarning',
    'mkdtemp',
    'register_package',
    'check_module_version_greater_or_equal',
    'validate_cuda_device',
    'validate_hpu_device',
    'location_tag',
    'default_restore_location',
    'normalize_storage_type',
    'storage_to_tensor_type',
    'save',
    'load',
    'StorageType',
    'LoadEndianness',
    'get_default_load_endianness',
    'set_default_load_endianness',
]
 |  | 
 |  | 
 | class SourceChangeWarning(Warning): | 
 |     pass | 
 |  | 
 |  | 
 | @contextmanager | 
 | def mkdtemp(): | 
 |     path = tempfile.mkdtemp() | 
 |     try: | 
 |         yield path | 
 |     finally: | 
 |         shutil.rmtree(path) | 
 |  | 
 |  | 
# Registry of (priority, tagger, deserializer) tuples filled by
# `register_package`. It is kept sorted so that `location_tag` and
# `default_restore_location` consult lower priority values first.
_package_registry = []
 |  | 
 | class LoadEndianness(Enum): | 
 |     NATIVE = 1 | 
 |     LITTLE = 2 | 
 |     BIG = 3 | 
 |  | 
# Fallback byte order returned by `get_default_load_endianness` and replaced by
# `set_default_load_endianness`; `None` until a fallback is explicitly set.
_default_load_endian: Optional[LoadEndianness] = None
 |  | 
 | def get_default_load_endianness() -> Optional[LoadEndianness]: | 
 |     ''' | 
 |     Get fallback byte order for loading files | 
 |  | 
 |     If byteorder mark is not present in saved checkpoint, | 
 |     this byte order is used as fallback. | 
 |     By default, it's "native" byte order. | 
 |  | 
 |     Returns: | 
 |         default_load_endian: Optional[LoadEndianness] | 
 |     ''' | 
 |     return _default_load_endian | 
 |  | 
 | def set_default_load_endianness(endianness): | 
 |     ''' | 
 |     Set fallback byte order for loading files | 
 |  | 
 |     If byteorder mark is not present in saved checkpoint, | 
 |     this byte order is used as fallback. | 
 |     By default, it's "native" byte order. | 
 |  | 
 |     Args: | 
 |         endianness: the new fallback byte order | 
 |     ''' | 
 |     global _default_load_endian | 
 |     if not isinstance(endianness, LoadEndianness) and endianness is not None: | 
 |         raise TypeError("Invalid argument type in function set_default_load_endianness") | 
 |     _default_load_endian = endianness | 
 |  | 
 | def _is_zipfile(f) -> bool: | 
 |     # This is a stricter implementation than zipfile.is_zipfile(). | 
 |     # zipfile.is_zipfile() is True if the magic number appears anywhere in the | 
 |     # binary. Since we expect the files here to be generated by torch.save or | 
 |     # torch.jit.save, it's safe to only check the start bytes and avoid | 
 |     # collisions and assume the zip has only 1 file. | 
 |     # See bugs.python.org/issue28494. | 
 |  | 
 |     start = f.tell() | 
 |     # Read the first few bytes and match against the ZIP file signature | 
 |     local_header_magic_number = b'PK\x03\x04' | 
 |     read_bytes = f.read(len(local_header_magic_number)) | 
 |     f.seek(start) | 
 |     return read_bytes == local_header_magic_number | 
 |  | 
 |  | 
 | def register_package( | 
 |     priority: int, | 
 |     tagger: Callable[[STORAGE], Optional[str]], | 
 |     deserializer: Callable[[STORAGE, str], Optional[STORAGE]] | 
 | ): | 
 |     ''' | 
 |     Registers callables for tagging and deserializing storage objects with an associated priority. | 
 |     Tagging associates a device with a storage object at save time while deserializing moves a | 
 |     storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer` | 
 |     are run in the order given by their :attr:`priority` until a tagger/deserializer returns a | 
 |     value that is not `None`. | 
 |  | 
 |     To override the deserialization behavior for a device in the global registry, one can register a | 
 |     tagger with a higher priority than the existing tagger. | 
 |  | 
 |     This function can also be used to register a tagger and deserializer for new devices. | 
 |  | 
 |     Args: | 
 |         priority: Indicates the priority associated with the tagger and deserializer, where a lower | 
 |             value indicates higher priority. | 
 |         tagger: Callable that takes in a storage object and returns its tagged device as a string | 
 |             or None. | 
 |         deserializer: Callable that takes in storage object and a device string and returns a storage | 
 |             object on the appropriate device or None. | 
 |  | 
 |     Returns: | 
 |         `None` | 
 |  | 
 |     Example: | 
 |         >>> def ipu_tag(obj): | 
 |         >>>     if obj.device.type == 'ipu': | 
 |         >>>         return 'ipu' | 
 |         >>> def ipu_deserialize(obj, location): | 
 |         >>>     if location.startswith('ipu'): | 
 |         >>>         ipu = getattr(torch, "ipu", None) | 
 |         >>>         assert ipu is not None, "IPU device module is not loaded" | 
 |         >>>         assert torch.ipu.is_available(), "ipu is not available" | 
 |         >>>         return obj.ipu(location) | 
 |         >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize) | 
 |     ''' | 
 |     queue_elem = (priority, tagger, deserializer) | 
 |     _package_registry.append(queue_elem) | 
 |     _package_registry.sort() | 
 |  | 
 |  | 
 | def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True): | 
 |     ''' | 
 |     Check if a module's version satisfies requirements | 
 |  | 
 |     Usually, a module's version string will be like 'x.y.z', which would be represented | 
 |     as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version | 
 |     string does not match the given tuple's format up to the length of the tuple, then | 
 |     error and exit or emit a warning. | 
 |  | 
 |     Args: | 
 |         module: the module to check the version of | 
 |         req_version_tuple: tuple (usually of ints) representing the required version | 
 |         error_if_malformed: whether we should exit if module version string is malformed | 
 |  | 
 |     Returns: | 
 |         requirement_is_met: bool | 
 |     ''' | 
 |     try: | 
 |         version_strs = module.__version__.split('.') | 
 |         # Cast module version fields to match the types of the required version | 
 |         module_version = tuple( | 
 |             type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple) | 
 |         ) | 
 |         requirement_is_met = module_version >= req_version_tuple | 
 |  | 
 |     except Exception as e: | 
 |         message = ( | 
 |             f"'{module.__name__}' module version string is malformed '{module.__version__}' and cannot be compared" | 
 |             f" with tuple {str(req_version_tuple)}" | 
 |         ) | 
 |         if error_if_malformed: | 
 |             raise RuntimeError(message) from e | 
 |         else: | 
 |             warnings.warn(message + ', but continuing assuming that requirement is met') | 
 |             requirement_is_met = True | 
 |  | 
 |     return requirement_is_met | 
 |  | 
 |  | 
 | def _cpu_tag(obj): | 
 |     if obj.device.type == 'cpu': | 
 |         return 'cpu' | 
 |  | 
 |  | 
 | def _cuda_tag(obj): | 
 |     if obj.device.type == 'cuda': | 
 |         return 'cuda:' + str(obj.device.index) | 
 |  | 
 | def _hpu_tag(obj): | 
 |     if obj.device.type == 'hpu': | 
 |         return 'hpu:' + str(obj.device.index) | 
 |  | 
 | def _mps_tag(obj): | 
 |     if obj.device.type == 'mps': | 
 |         return 'mps' | 
 |  | 
 |  | 
 | def _meta_tag(obj): | 
 |     if obj.device.type == 'meta': | 
 |         return 'meta' | 
 |  | 
 |  | 
 | def _privateuse1_tag(obj): | 
 |     backend_name = torch._C._get_privateuse1_backend_name() | 
 |     if obj.device.type == backend_name: | 
 |         if obj.device.index is None: | 
 |             return backend_name | 
 |         else: | 
 |             return backend_name + ':' + str(obj.device.index) | 
 |  | 
 |  | 
 | def _cpu_deserialize(obj, location): | 
 |     if location == 'cpu': | 
 |         return obj | 
 |  | 
 |  | 
 | def validate_cuda_device(location): | 
 |     device = torch.cuda._utils._get_device_index(location, True) | 
 |  | 
 |     if not torch.cuda.is_available(): | 
 |         raise RuntimeError('Attempting to deserialize object on a CUDA ' | 
 |                            'device but torch.cuda.is_available() is False. ' | 
 |                            'If you are running on a CPU-only machine, ' | 
 |                            'please use torch.load with map_location=torch.device(\'cpu\') ' | 
 |                            'to map your storages to the CPU.') | 
 |     device_count = torch.cuda.device_count() | 
 |     if device >= device_count: | 
 |         raise RuntimeError('Attempting to deserialize object on CUDA device ' | 
 |                            f'{device} but torch.cuda.device_count() is {device_count}. Please use ' | 
 |                            'torch.load with map_location to map your storages ' | 
 |                            'to an existing device.') | 
 |     return device | 
 |  | 
 |  | 
 | def _cuda_deserialize(obj, location): | 
 |     if location.startswith('cuda'): | 
 |         device = validate_cuda_device(location) | 
 |         if getattr(obj, "_torch_load_uninitialized", False): | 
 |             with torch.cuda.device(device): | 
 |                 return torch.UntypedStorage(obj.nbytes(), device=torch.device(location)) | 
 |         else: | 
 |             return obj.cuda(device) | 
 |  | 
 |  | 
 | def validate_hpu_device(location): | 
 |     hpu = getattr(torch, "hpu", None) | 
 |     assert hpu is not None, "HPU device module is not loaded" | 
 |     device = hpu._utils._get_device_index(location, optional=True) | 
 |  | 
 |     if not hpu.is_available(): | 
 |         raise RuntimeError('Attempting to deserialize object on a HPU ' | 
 |                            'device but torch.hpu.is_available() is False. ' | 
 |                            'If you are running on a CPU-only machine, ' | 
 |                            'please use torch.load with map_location=torch.device(\'cpu\') ' | 
 |                            'to map your storages to the CPU.') | 
 |     device_count = hpu.device_count() | 
 |     if device >= device_count: | 
 |         raise RuntimeError('Attempting to deserialize object on HPU device ' | 
 |                            f'{device} but torch.hpu.device_count() is {device_count}. Please use ' | 
 |                            'torch.load with map_location to map your storages ' | 
 |                            'to an existing device.') | 
 |     return device | 
 |  | 
 |  | 
 | def _hpu_deserialize(obj, location): | 
 |     if location.startswith('hpu'): | 
 |         hpu = getattr(torch, "hpu", None) | 
 |         assert hpu is not None, "HPU device module is not loaded" | 
 |         device = validate_hpu_device(location) | 
 |         if getattr(obj, "_torch_load_uninitialized", False): | 
 |             with hpu.device(device): | 
 |                 return torch.UntypedStorage(obj.nbytes(), device=torch.device(location)) | 
 |         else: | 
 |             return obj.hpu(device) | 
 |  | 
 |  | 
 | def _mps_deserialize(obj, location): | 
 |     if location.startswith('mps'): | 
 |         return obj.mps() | 
 |  | 
 |  | 
 | def _meta_deserialize(obj, location): | 
 |     if location == 'meta': | 
 |         return torch.UntypedStorage(obj.nbytes(), device='meta') | 
 |  | 
 |  | 
 | def _validate_privateuse1_device(location, backend_name): | 
 |     ''' | 
 |     Check whether the device index of privateuse1 is valid | 
 |  | 
 |     Register a device_module of privateuse1 by torch._register_device_module. | 
 |     Implement the following methods in device_module like cuda: | 
 |     device_module._utils._get_device_index(location, True), | 
 |     device_module.device_count(). | 
 |  | 
 |     Args: | 
 |         location: string of device | 
 |         backend_name: the name of privateuse1, which can be renamed | 
 |  | 
 |     Returns: | 
 |         device_index: int | 
 |     ''' | 
 |     if not hasattr(torch, backend_name): | 
 |         raise RuntimeError(f'The {backend_name.upper()} device module is not registered. ' | 
 |                            'If you are running on a CPU-only machine, ' | 
 |                            'please use torch.load with map_location=torch.device(\'cpu\') ' | 
 |                            'to map your storages to the CPU.') | 
 |     device_module = getattr(torch, backend_name) | 
 |     if hasattr(device_module, '_utils') and hasattr(device_module._utils, '_get_device_index'): | 
 |         device_index = device_module._utils._get_device_index(location, True) | 
 |     else: | 
 |         device = torch.device(location) | 
 |         device_index = device.index if device.index else 0 | 
 |     if hasattr(device_module, 'is_available') and not device_module.is_available(): | 
 |         raise RuntimeError(f'Attempting to deserialize object on a {backend_name.upper()} ' | 
 |                            f'device but torch.{backend_name}.is_available() is False. ' | 
 |                            'If you are running on a CPU-only machine, ' | 
 |                            'please use torch.load with map_location=torch.device(\'cpu\') ' | 
 |                            'to map your storages to the CPU.') | 
 |     if hasattr(device_module, 'device_count'): | 
 |         device_count = device_module.device_count() | 
 |         if device_index >= device_count: | 
 |             raise RuntimeError(f'Attempting to deserialize object on {backend_name.upper()} device ' | 
 |                                f'{device_index} but torch.{backend_name}.device_count() is {device_count}. ' | 
 |                                'Please use torch.load with map_location to map your storages ' | 
 |                                'to an existing device.') | 
 |     return device_index | 
 |  | 
 |  | 
 | def _privateuse1_deserialize(obj, location): | 
 |     backend_name = torch._C._get_privateuse1_backend_name() | 
 |     if location.startswith(backend_name): | 
 |         if not hasattr(obj, backend_name): | 
 |             raise RuntimeError(f'Attempting to load the storages to the {backend_name.upper()} device ' | 
 |                                f'but torch.storage._StorageBase.{backend_name}() or ' | 
 |                                f'torch.storage.TypedStorage.{backend_name}() is not generated. ' | 
 |                                'Please use torch.utils.generate_methods_for_privateuse1_backend ' | 
 |                                f'to generate storage.{backend_name}() method first.') | 
 |         device_index = _validate_privateuse1_device(location, backend_name) | 
 |         return getattr(obj, backend_name)(device_index) | 
 |  | 
 |  | 
# Register the built-in device handlers. Lower priority values are consulted
# first by `location_tag` and `default_restore_location`, so the order is CPU
# (10), CUDA (20), MPS (21), meta (22), the privateuse1 backend (23), HPU (24).
register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)
register_package(21, _mps_tag, _mps_deserialize)
register_package(22, _meta_tag, _meta_deserialize)
register_package(23, _privateuse1_tag, _privateuse1_deserialize)
register_package(24, _hpu_tag, _hpu_deserialize)
 |  | 
 |  | 
 | def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]): | 
 |     for _, tagger, _ in _package_registry: | 
 |         location = tagger(storage) | 
 |         if location: | 
 |             return location | 
 |     raise RuntimeError("don't know how to determine data location of " | 
 |                        + torch.typename(storage)) | 
 |  | 
 |  | 
 | def default_restore_location(storage, location): | 
 |     for _, _, fn in _package_registry: | 
 |         result = fn(storage, location) | 
 |         if result is not None: | 
 |             return result | 
 |     raise RuntimeError("don't know how to restore data location of " | 
 |                        + torch.typename(storage) + " (tagged with " | 
 |                        + location + ")") | 
 |  | 
 |  | 
 | def normalize_storage_type(storage_type): | 
 |     return getattr(torch, storage_type.__name__) | 
 |  | 
 |  | 
 | def storage_to_tensor_type(storage): | 
 |     storage_type = type(storage) | 
 |     module = _import_dotted_name(storage_type.__module__) | 
 |     return getattr(module, storage_type.__name__.replace('Storage', 'Tensor')) | 
 |  | 
 |  | 
 | def _is_path(name_or_buffer): | 
 |     return isinstance(name_or_buffer, (str, pathlib.Path)) | 
 |  | 
 |  | 
 | class _opener: | 
 |     def __init__(self, file_like): | 
 |         self.file_like = file_like | 
 |  | 
 |     def __enter__(self): | 
 |         return self.file_like | 
 |  | 
 |     def __exit__(self, *args): | 
 |         pass | 
 |  | 
 |  | 
 | class _open_file(_opener): | 
 |     def __init__(self, name, mode): | 
 |         super().__init__(open(name, mode)) | 
 |  | 
 |     def __exit__(self, *args): | 
 |         self.file_like.close() | 
 |  | 
 |  | 
 | class _open_buffer_reader(_opener): | 
 |     def __init__(self, buffer): | 
 |         super().__init__(buffer) | 
 |         _check_seekable(buffer) | 
 |  | 
 |  | 
 | class _open_buffer_writer(_opener): | 
 |     def __exit__(self, *args): | 
 |         self.file_like.flush() | 
 |  | 
 |  | 
 | def _open_file_like(name_or_buffer, mode): | 
 |     if _is_path(name_or_buffer): | 
 |         return _open_file(name_or_buffer, mode) | 
 |     else: | 
 |         if 'w' in mode: | 
 |             return _open_buffer_writer(name_or_buffer) | 
 |         elif 'r' in mode: | 
 |             return _open_buffer_reader(name_or_buffer) | 
 |         else: | 
 |             raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}") | 
 |  | 
 |  | 
 | class _open_zipfile_reader(_opener): | 
 |     def __init__(self, name_or_buffer) -> None: | 
 |         super().__init__(torch._C.PyTorchFileReader(name_or_buffer)) | 
 |  | 
 |  | 
 | class _open_zipfile_writer_file(_opener): | 
 |     def __init__(self, name) -> None: | 
 |         self.file_stream = None | 
 |         self.name = str(name) | 
 |         try: | 
 |             self.name.encode('ascii') | 
 |         except UnicodeEncodeError: | 
 |             # PyTorchFileWriter only supports ascii filename. | 
 |             # For filenames with non-ascii characters, we rely on Python | 
 |             # for writing out the file. | 
 |             self.file_stream = io.FileIO(self.name, mode='w') | 
 |             super().__init__(torch._C.PyTorchFileWriter(self.file_stream)) | 
 |         else: | 
 |             super().__init__(torch._C.PyTorchFileWriter(self.name)) | 
 |  | 
 |     def __exit__(self, *args) -> None: | 
 |         self.file_like.write_end_of_file() | 
 |         if self.file_stream is not None: | 
 |             self.file_stream.close() | 
 |  | 
 |  | 
 | class _open_zipfile_writer_buffer(_opener): | 
 |     def __init__(self, buffer) -> None: | 
 |         if not callable(getattr(buffer, "write", None)): | 
 |             msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'" | 
 |             if not hasattr(buffer, "write"): | 
 |                 raise AttributeError(msg) | 
 |             raise TypeError(msg) | 
 |         self.buffer = buffer | 
 |         super().__init__(torch._C.PyTorchFileWriter(buffer)) | 
 |  | 
 |     def __exit__(self, *args) -> None: | 
 |         self.file_like.write_end_of_file() | 
 |         self.buffer.flush() | 
 |  | 
 |  | 
 | def _open_zipfile_writer(name_or_buffer): | 
 |     container: Type[_opener] | 
 |     if _is_path(name_or_buffer): | 
 |         container = _open_zipfile_writer_file | 
 |     else: | 
 |         container = _open_zipfile_writer_buffer | 
 |     return container(name_or_buffer) | 
 |  | 
 |  | 
 | def _is_compressed_file(f) -> bool: | 
 |     compress_modules = ['gzip'] | 
 |     try: | 
 |         return f.__module__ in compress_modules | 
 |     except AttributeError: | 
 |         return False | 
 |  | 
 |  | 
 | def _should_read_directly(f): | 
 |     """ | 
 |     Checks if f is a file that should be read directly. It should be read | 
 |     directly if it is backed by a real file (has a fileno) and is not a | 
 |     a compressed file (e.g. gzip) | 
 |     """ | 
 |     if _is_compressed_file(f): | 
 |         return False | 
 |     try: | 
 |         return f.fileno() >= 0 | 
 |     except io.UnsupportedOperation: | 
 |         return False | 
 |     except AttributeError: | 
 |         return False | 
 |  | 
 |  | 
 | def _check_seekable(f) -> bool: | 
 |  | 
 |     def raise_err_msg(patterns, e): | 
 |         for p in patterns: | 
 |             if p in str(e): | 
 |                 msg = (str(e) + ". You can only torch.load from a file that is seekable." | 
 |                                 + " Please pre-load the data into a buffer like io.BytesIO and" | 
 |                                 + " try to load from it instead.") | 
 |                 raise type(e)(msg) | 
 |         raise e | 
 |  | 
 |     try: | 
 |         f.seek(f.tell()) | 
 |         return True | 
 |     except (io.UnsupportedOperation, AttributeError) as e: | 
 |         raise_err_msg(["seek", "tell"], e) | 
 |     return False | 
 |  | 
 |  | 
 | def _check_dill_version(pickle_module) -> None: | 
 |     '''Checks if using dill as the pickle module, and if so, checks if it is the correct version. | 
 |     If dill version is lower than 0.3.1, a ValueError is raised. | 
 |  | 
 |     Args: | 
 |         pickle_module: module used for pickling metadata and objects | 
 |  | 
 |     ''' | 
 |     if pickle_module is not None and pickle_module.__name__ == 'dill': | 
 |         required_dill_version = (0, 3, 1) | 
 |         if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False): | 
 |             raise ValueError(( | 
 |                 "'torch' supports dill >= {}, but you have dill {}." | 
 |                 " Please upgrade dill or switch to 'pickle'" | 
 |             ).format( | 
 |                 '.'.join([str(num) for num in required_dill_version]), | 
 |                 pickle_module.__version__ | 
 |             )) | 
 |  | 
 |  | 
 | def _check_save_filelike(f): | 
 |     if not isinstance(f, (str, os.PathLike)) and not hasattr(f, 'write'): | 
 |         raise AttributeError( | 
 |             "expected 'f' to be string, path, or a file-like object with " | 
 |             "a 'write' attribute") | 
 |  | 
 |  | 
 | def save( | 
 |     obj: object, | 
 |     f: FILE_LIKE, | 
 |     pickle_module: Any = pickle, | 
 |     pickle_protocol: int = DEFAULT_PROTOCOL, | 
 |     _use_new_zipfile_serialization: bool = True, | 
 |     _disable_byteorder_record: bool = False | 
 | ) -> None: | 
 |     # Reference: https://github.com/pytorch/pytorch/issues/54354 | 
 |     # The first line of this docstring overrides the one Sphinx generates for the | 
 |     # documentation. We need it so that Sphinx doesn't leak `pickle`s path from | 
 |     # the build environment (e.g. `<module 'pickle' from '/leaked/path'). | 
 |  | 
 |     """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) | 
 |  | 
 |     Saves an object to a disk file. | 
 |  | 
 |     See also: :ref:`saving-loading-tensors` | 
 |  | 
 |     Args: | 
 |         obj: saved object | 
 |         f: a file-like object (has to implement write and flush) or a string or | 
 |            os.PathLike object containing a file name | 
 |         pickle_module: module used for pickling metadata and objects | 
 |         pickle_protocol: can be specified to override the default protocol | 
 |  | 
 |     .. note:: | 
 |         A common PyTorch convention is to save tensors using .pt file extension. | 
 |  | 
 |     .. note:: | 
 |         PyTorch preserves storage sharing across serialization. See | 
 |         :ref:`preserve-storage-sharing` for more details. | 
 |  | 
 |     .. note:: | 
 |         The 1.6 release of PyTorch switched ``torch.save`` to use a new | 
 |         zipfile-based file format. ``torch.load`` still retains the ability to | 
 |         load files in the old format. If for any reason you want ``torch.save`` | 
 |         to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``. | 
 |  | 
 |     Example: | 
 |         >>> # xdoctest: +SKIP("makes cwd dirty") | 
 |         >>> # Save to file | 
 |         >>> x = torch.tensor([0, 1, 2, 3, 4]) | 
 |         >>> torch.save(x, 'tensor.pt') | 
 |         >>> # Save to io.BytesIO buffer | 
 |         >>> buffer = io.BytesIO() | 
 |         >>> torch.save(x, buffer) | 
 |     """ | 
 |     torch._C._log_api_usage_once("torch.save") | 
 |     _check_dill_version(pickle_module) | 
 |     _check_save_filelike(f) | 
 |  | 
 |     if _use_new_zipfile_serialization: | 
 |         with _open_zipfile_writer(f) as opened_zipfile: | 
 |             _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record) | 
 |             return | 
 |     else: | 
 |         with _open_file_like(f, 'wb') as opened_file: | 
 |             _legacy_save(obj, opened_file, pickle_module, pickle_protocol) | 
 |  | 
 |  | 
 | def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None: | 
 |     import torch.nn as nn | 
 |     serialized_container_types = {} | 
 |     serialized_storages = {} | 
 |  | 
 |     # Since loading storages that view the same data with different dtypes is | 
 |     # not supported, we need to keep track of the dtype associated with each | 
 |     # storage data_ptr and throw an error if the dtype is ever different. | 
 |     # TODO: This feature could be added in the future | 
 |     storage_dtypes: Dict[int, torch.dtype] = {} | 
 |  | 
 |     def persistent_id(obj: Any) -> Optional[Tuple]: | 
 |         # FIXME: the docs say that persistent_id should only return a string | 
 |         # but torch store returns tuples. This works only in the binary protocol | 
 |         # see | 
 |         # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects | 
 |         # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 | 
 |         if isinstance(obj, type) and issubclass(obj, nn.Module): | 
 |             if obj in serialized_container_types: | 
 |                 return None | 
 |             serialized_container_types[obj] = True | 
 |             source_file = source = None | 
 |             try: | 
 |                 source_lines, _, source_file = get_source_lines_and_file(obj) | 
 |                 source = ''.join(source_lines) | 
 |             except Exception:  # saving the source is optional, so we can ignore any errors | 
 |                 warnings.warn("Couldn't retrieve source code for container of " | 
 |                               "type " + obj.__name__ + ". It won't be checked " | 
 |                               "for correctness upon loading.") | 
 |             return ('module', obj, source_file, source) | 
 |  | 
 |         if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): | 
 |             storage: torch.UntypedStorage | 
 |  | 
 |             if isinstance(obj, torch.storage.TypedStorage): | 
 |                 # TODO: Once we decide to break serialization FC, this case | 
 |                 # can be deleted | 
 |                 storage = obj._untyped_storage | 
 |                 storage_dtype = obj.dtype | 
 |                 storage_type_str = obj._pickle_storage_type() | 
 |                 storage_type = getattr(torch, storage_type_str) | 
 |                 dtype = obj.dtype | 
 |                 storage_numel = obj._size() | 
 |  | 
 |             elif isinstance(obj, torch.UntypedStorage): | 
 |                 storage = obj | 
 |                 storage_dtype = torch.uint8 | 
 |                 storage_type = normalize_storage_type(type(obj)) | 
 |                 dtype = torch.uint8 | 
 |                 storage_numel = storage.nbytes() | 
 |             else: | 
 |                 raise TypeError(f'type not recognized: {type(obj)}') | 
 |  | 
 |             # If storage is allocated, ensure that any other saved storages | 
 |             # pointing to the same data all have the same dtype. If storage is | 
 |             # not allocated, don't perform this check | 
 |             if storage.data_ptr() != 0: | 
 |                 if storage.data_ptr() in storage_dtypes: | 
 |                     if storage_dtype != storage_dtypes[storage.data_ptr()]: | 
 |                         raise RuntimeError( | 
 |                             'Cannot save multiple tensors or storages that ' | 
 |                             'view the same data as different types') | 
 |                 else: | 
 |                     storage_dtypes[storage.data_ptr()] = storage_dtype | 
 |  | 
 |             view_metadata: Optional[Tuple[str, int, int]] | 
 |  | 
 |             # Offset is always 0, but we keep it for backwards compatibility | 
 |             # with the old serialization format (which supported storage views) | 
 |             offset = 0 | 
 |             storage_key = str(storage._cdata) | 
 |             location = location_tag(storage) | 
 |  | 
 |             # TODO: There's an issue here with FC. It might be impossible to | 
 |             # solve, but it's worth noting. Imagine we save a list `[storage, | 
 |             # tensor]`, where `tensor.storage()` is the same as `storage`, and | 
 |             # `tensor.element_size() > 1`. Let's say that `tensor.dtype == | 
 |             # torch.float`.  The storage will be serialized with element size | 
 |             # of 1, since we're choosing to serialize the first occurance of | 
 |             # a duplicate storage. Since this legacy serialization format saves | 
 |             # the numel of the storage, rather than nbytes directly, we'll be | 
 |             # effectively saving nbytes in this case.  We'll be able to load it | 
 |             # and the tensor back up with no problems in _this_ and future | 
 |             # versions of pytorch, but in older versions, here's the problem: | 
 |             # the storage will be loaded up as a UntypedStorage, and then the | 
 |             # FloatTensor will loaded and the UntypedStorage will be assigned to | 
 |             # it. Since the storage dtype does not match the tensor dtype, this | 
 |             # will cause an error.  If we reverse the list, like `[tensor, | 
 |             # storage]`, then we will save the `tensor.storage()` as a faked | 
 |             # `FloatStorage`, and the saved size will be the correct | 
 |             # dtype-specific numel count that old versions expect. `tensor` | 
 |             # will be able to load up properly in old versions, pointing to | 
 |             # a FloatStorage. However, `storage` is still being translated to | 
 |             # a UntypedStorage, and it will try to resolve to the same | 
 |             # FloatStorage that `tensor` contains. This will also cause an | 
 |             # error. It doesn't seem like there's any way around this. | 
 |             # Probably, we just cannot maintain FC for the legacy format if the | 
 |             # saved list contains both a tensor and a storage that point to the | 
 |             # same data.  We should still be able to maintain FC for lists of | 
 |             # just tensors, as long as all views share the same dtype as the | 
 |             # tensor they are viewing. | 
 |  | 
 |             if storage_key not in serialized_storages: | 
 |                 serialized_storages[storage_key] = (storage, dtype) | 
 |             is_view = storage._cdata != storage._cdata | 
 |             if is_view: | 
 |                 view_metadata = (str(storage._cdata), offset, storage.nbytes()) | 
 |             else: | 
 |                 view_metadata = None | 
 |  | 
 |             res = ('storage', | 
 |                    storage_type, | 
 |                    storage_key, | 
 |                    location, | 
 |                    storage_numel, | 
 |                    view_metadata) | 
 |             return res | 
 |         return None | 
 |  | 
 |     sys_info = dict( | 
 |         protocol_version=PROTOCOL_VERSION, | 
 |         little_endian=sys.byteorder == 'little', | 
 |         type_sizes=dict( | 
 |             short=SHORT_SIZE, | 
 |             int=INT_SIZE, | 
 |             long=LONG_SIZE, | 
 |         ), | 
 |     ) | 
 |  | 
 |     pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol) | 
 |     pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol) | 
 |     pickle_module.dump(sys_info, f, protocol=pickle_protocol) | 
 |     pickler = pickle_module.Pickler(f, protocol=pickle_protocol) | 
 |     pickler.persistent_id = persistent_id | 
 |     pickler.dump(obj) | 
 |  | 
 |     serialized_storage_keys = sorted(serialized_storages.keys()) | 
 |     pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol) | 
 |     f.flush() | 
 |     for key in serialized_storage_keys: | 
 |         storage, dtype = serialized_storages[key] | 
 |         storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype)) | 
 |  | 
 |  | 
 | def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record): | 
 |     serialized_storages = {} | 
 |     id_map: Dict[int, str] = {} | 
 |  | 
 |     # Since loading storages that view the same data with different dtypes is | 
 |     # not supported, we need to keep track of the dtype associated with each | 
 |     # storage data_ptr and throw an error if the dtype is ever different. | 
 |     # TODO: This feature could be added in the future | 
 |     storage_dtypes: Dict[int, torch.dtype] = {} | 
 |  | 
 |     def persistent_id(obj): | 
 |         # FIXME: the docs say that persistent_id should only return a string | 
 |         # but torch store returns tuples. This works only in the binary protocol | 
 |         # see | 
 |         # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects | 
 |         # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 | 
 |         if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): | 
 |  | 
 |             if isinstance(obj, torch.storage.TypedStorage): | 
 |                 # TODO: Once we decide to break serialization FC, this case | 
 |                 # can be deleted | 
 |                 storage = obj._untyped_storage | 
 |                 storage_dtype = obj.dtype | 
 |                 storage_type_str = obj._pickle_storage_type() | 
 |                 storage_type = getattr(torch, storage_type_str) | 
 |                 storage_numel = obj._size() | 
 |  | 
 |             else: | 
 |                 storage = obj | 
 |                 storage_dtype = torch.uint8 | 
 |                 storage_type = normalize_storage_type(type(obj)) | 
 |                 storage_numel = storage.nbytes() | 
 |  | 
 |             # If storage is allocated, ensure that any other saved storages | 
 |             # pointing to the same data all have the same dtype. If storage is | 
 |             # not allocated, don't perform this check | 
 |             if storage.data_ptr() != 0: | 
 |                 if storage.data_ptr() in storage_dtypes: | 
 |                     if storage_dtype != storage_dtypes[storage.data_ptr()]: | 
 |                         raise RuntimeError( | 
 |                             'Cannot save multiple tensors or storages that ' | 
 |                             'view the same data as different types') | 
 |                 else: | 
 |                     storage_dtypes[storage.data_ptr()] = storage_dtype | 
 |  | 
 |             storage_key = id_map.setdefault(storage._cdata, str(len(id_map))) | 
 |             location = location_tag(storage) | 
 |             serialized_storages[storage_key] = storage | 
 |  | 
 |             return ('storage', | 
 |                     storage_type, | 
 |                     storage_key, | 
 |                     location, | 
 |                     storage_numel) | 
 |  | 
 |         return None | 
 |  | 
 |     # Write the pickle data for `obj` | 
 |     data_buf = io.BytesIO() | 
 |     pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol) | 
 |     pickler.persistent_id = persistent_id | 
 |     pickler.dump(obj) | 
 |     data_value = data_buf.getvalue() | 
 |     zip_file.write_record('data.pkl', data_value, len(data_value)) | 
 |  | 
 |     # Write byte order marker | 
 |     if not _disable_byteorder_record: | 
 |         if sys.byteorder not in ['little', 'big']: | 
 |             raise ValueError('Unknown endianness type: ' + sys.byteorder) | 
 |  | 
 |         zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder)) | 
 |  | 
 |     # Write each tensor to a file named tensor/the_tensor_key in the zip archive | 
 |     for key in sorted(serialized_storages.keys()): | 
 |         name = f'data/{key}' | 
 |         storage = serialized_storages[key] | 
 |         # given that we copy things around anyway, we might use storage.cpu() | 
 |         # this means to that to get tensors serialized, you need to implement | 
 |         # .cpu() on the underlying Storage | 
 |         if storage.device.type != 'cpu': | 
 |             storage = storage.cpu() | 
 |         # Now that it is on the CPU we can directly copy it into the zip file | 
 |         num_bytes = storage.nbytes() | 
 |         zip_file.write_record(name, storage.data_ptr(), num_bytes) | 
 |  | 
 |  | 
 | def load( | 
 |     f: FILE_LIKE, | 
 |     map_location: MAP_LOCATION = None, | 
 |     pickle_module: Any = None, | 
 |     *, | 
 |     weights_only: bool = False, | 
 |     mmap: Optional[bool] = None, | 
 |     **pickle_load_args: Any | 
 | ) -> Any: | 
 |     # Reference: https://github.com/pytorch/pytorch/issues/54354 | 
 |     # The first line of this docstring overrides the one Sphinx generates for the | 
 |     # documentation. We need it so that Sphinx doesn't leak `pickle`s path from | 
 |     # the build environment (e.g. `<module 'pickle' from '/leaked/path'). | 
 |  | 
 |     """load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args) | 
 |  | 
 |     Loads an object saved with :func:`torch.save` from a file. | 
 |  | 
 |     :func:`torch.load` uses Python's unpickling facilities but treats storages, | 
 |     which underlie tensors, specially. They are first deserialized on the | 
 |     CPU and are then moved to the device they were saved from. If this fails | 
 |     (e.g. because the run time system doesn't have certain devices), an exception | 
 |     is raised. However, storages can be dynamically remapped to an alternative | 
 |     set of devices using the :attr:`map_location` argument. | 
 |  | 
 |     If :attr:`map_location` is a callable, it will be called once for each serialized | 
 |     storage with two arguments: storage and location. The storage argument | 
 |     will be the initial deserialization of the storage, residing on the CPU. | 
 |     Each serialized storage has a location tag associated with it which | 
 |     identifies the device it was saved from, and this tag is the second | 
 |     argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'`` | 
 |     for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors. | 
 |     :attr:`map_location` should return either ``None`` or a storage. If | 
 |     :attr:`map_location` returns a storage, it will be used as the final deserialized | 
 |     object, already moved to the right device. Otherwise, :func:`torch.load` will | 
 |     fall back to the default behavior, as if :attr:`map_location` wasn't specified. | 
 |  | 
 |     If :attr:`map_location` is a :class:`torch.device` object or a string containing | 
 |     a device tag, it indicates the location where all tensors should be loaded. | 
 |  | 
 |     Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags | 
 |     appearing in the file (keys), to ones that specify where to put the | 
 |     storages (values). | 
 |  | 
 |     User extensions can register their own location tags and tagging and | 
 |     deserialization methods using :func:`torch.serialization.register_package`. | 
 |  | 
 |     Args: | 
 |         f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`), | 
 |             or a string or os.PathLike object containing a file name | 
 |         map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage | 
 |             locations | 
 |         pickle_module: module used for unpickling metadata and objects (has to | 
 |             match the :attr:`pickle_module` used to serialize file) | 
 |         weights_only: Indicates whether unpickler should be restricted to | 
 |             loading only tensors, primitive types and dictionaries | 
 |         mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory. | 
 |             Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they | 
 |             are moved to the location that they were tagged with when saving, or specified by ``map_location``. This | 
 |             second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the | 
 |             tensor storages from disk to CPU memory in the first step, ``f`` is mmaped. | 
 |         pickle_load_args: (Python 3 only) optional keyword arguments passed over to | 
 |             :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g., | 
 |             :attr:`errors=...`. | 
 |  | 
 |     .. warning:: | 
 |         :func:`torch.load()` unless `weights_only` parameter is set to `True`, | 
 |         uses ``pickle`` module implicitly, which is known to be insecure. | 
 |         It is possible to construct malicious pickle data which will execute arbitrary code | 
 |         during unpickling. Never load data that could have come from an untrusted | 
 |         source in an unsafe mode, or that could have been tampered with. **Only load data you trust**. | 
 |  | 
 |     .. note:: | 
 |         When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors | 
 |         will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')`` | 
 |         and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint. | 
 |  | 
 |     .. note:: | 
 |         By default, we decode byte strings as ``utf-8``.  This is to avoid a common error | 
 |         case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...`` | 
 |         when loading files saved by Python 2 in Python 3.  If this default | 
 |         is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how | 
 |         these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them | 
 |         to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them | 
 |         as byte arrays which can be decoded later with ``byte_array.decode(...)``. | 
 |  | 
 |     Example: | 
 |         >>> # xdoctest: +SKIP("undefined filepaths") | 
 |         >>> torch.load('tensors.pt', weights_only=True) | 
 |         # Load all tensors onto the CPU | 
 |         >>> torch.load('tensors.pt', map_location=torch.device('cpu'), weights_only=True) | 
 |         # Load all tensors onto the CPU, using a function | 
 |         >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage, weights_only=True) | 
 |         # Load all tensors onto GPU 1 | 
 |         >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1), weights_only=True) | 
 |         # Map tensors from GPU 1 to GPU 0 | 
 |         >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}, weights_only=True) | 
 |         # Load tensor from io.BytesIO object | 
 |         # Loading from a buffer setting weights_only=False, warning this can be unsafe | 
 |         >>> with open('tensor.pt', 'rb') as f: | 
 |         ...     buffer = io.BytesIO(f.read()) | 
 |         >>> torch.load(buffer, weights_only=False) | 
 |         # Load a module with 'ascii' encoding for unpickling | 
 |         # Loading from a module setting weights_only=False, warning this can be unsafe | 
 |         >>> torch.load('module.pt', encoding='ascii', weights_only=False) | 
 |     """ | 
 |     torch._C._log_api_usage_once("torch.load") | 
 |     UNSAFE_MESSAGE = ( | 
 |         "Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`" | 
 |         " will likely succeed, but it can result in arbitrary code execution." | 
 |         "Do it only if you get the file from a trusted source. WeightsUnpickler error: " | 
 |     ) | 
 |     # Add ability to force safe only weight loads via environment variable | 
 |     if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']: | 
 |         weights_only = True | 
 |  | 
 |     if weights_only: | 
 |         if pickle_module is not None: | 
 |             raise RuntimeError("Can not safely load weights when explicit pickle_module is specified") | 
 |     else: | 
 |         if pickle_module is None: | 
 |             pickle_module = pickle | 
 |  | 
 |     # make flipping default BC-compatible | 
 |     if mmap is None: | 
 |         mmap = False | 
 |  | 
 |     _check_dill_version(pickle_module) | 
 |  | 
 |     if 'encoding' not in pickle_load_args.keys(): | 
 |         pickle_load_args['encoding'] = 'utf-8' | 
 |  | 
 |     with _open_file_like(f, 'rb') as opened_file: | 
 |         if _is_zipfile(opened_file): | 
 |             # The zipfile reader is going to advance the current file position. | 
 |             # If we want to actually tail call to torch.jit.load, we need to | 
 |             # reset back to the original position. | 
 |             orig_position = opened_file.tell() | 
 |             overall_storage = None | 
 |             with _open_zipfile_reader(opened_file) as opened_zipfile: | 
 |                 if _is_torchscript_zip(opened_zipfile): | 
 |                     warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive" | 
 |                                   " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to" | 
 |                                   " silence this warning)", UserWarning) | 
 |                     opened_file.seek(orig_position) | 
 |                     return torch.jit.load(opened_file, map_location=map_location) | 
 |                 if mmap: | 
 |                     if not isinstance(f, str): | 
 |                         raise ValueError("f must be a string filename in order to use mmap argument") | 
 |                     size = os.path.getsize(f) | 
 |                     overall_storage = torch.UntypedStorage.from_file(f, False, size) | 
 |                 if weights_only: | 
 |                     try: | 
 |                         return _load(opened_zipfile, | 
 |                                      map_location, | 
 |                                      _weights_only_unpickler, | 
 |                                      overall_storage=overall_storage, | 
 |                                      **pickle_load_args) | 
 |                     except RuntimeError as e: | 
 |                         raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None | 
 |                 return _load(opened_zipfile, | 
 |                              map_location, | 
 |                              pickle_module, | 
 |                              overall_storage=overall_storage, | 
 |                              **pickle_load_args) | 
 |         if mmap: | 
 |             raise RuntimeError("mmap can only be used with files saved with " | 
 |                                "`torch.save(_use_new_zipfile_serialization=True), " | 
 |                                "please torch.save your checkpoint with this option in order to use mmap.") | 
 |         if weights_only: | 
 |             try: | 
 |                 return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args) | 
 |             except RuntimeError as e: | 
 |                 raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None | 
 |         return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args) | 
 |  | 
 |  | 
 | # Register pickling support for layout instances such as | 
 | # torch.sparse_coo, etc | 
 | def _get_layout(name): | 
 |     """Get layout extension object from its string representation. | 
 |     """ | 
 |     cache = _get_layout.cache   # type: ignore[attr-defined] | 
 |     if not cache: | 
 |         for v in torch.__dict__.values(): | 
 |             if isinstance(v, torch.layout): | 
 |                 cache[str(v)] = v | 
 |     return cache[name] | 
 |  | 
 | # There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087 | 
 | _get_layout.cache = {}   # type: ignore[attr-defined] | 
 | copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),))) | 
 |  | 
 |  | 
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    """Load an object saved in one of the pre-zipfile serialization formats.

    Two legacy layouts are handled:

    * the tar-based format (members ``storages``, ``tensors`` and ``pickle``).
      It is attempted only when ``f`` can be read directly and is positioned
      at offset 0;
    * the sequential pickle format: magic number, protocol version, system
      info, the pickled object, the list of storage keys, and then the raw
      bytes of each storage in that key order.

    Args:
        f: seekable file-like object positioned at the start of the data.
        map_location: passed to ``_get_restore_location`` to build the
            callback that places each deserialized storage.
        pickle_module: module providing ``load`` and ``Unpickler``.
        **pickle_load_args: extra keyword arguments forwarded to every
            ``pickle_module.load`` call and to the ``Unpickler``.

    Returns:
        The unpickled object.

    Raises:
        RuntimeError: if the tar path finds a zip archive instead, if the
            magic number or protocol version does not match, on Python
            3.8.0/3.8.1 when ``f`` lacks ``readinto``, or when a persistent
            id has an unknown type.
    """
    # NOTE(review): this binding is replaced by a fresh dict further down,
    # before any nested function uses it, so it appears to be redundant.
    deserialized_objects: Dict[int, Any] = {}

    # Callback that moves each deserialized storage to its final device.
    restore_location = _get_restore_location(map_location)

    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
        """Unpickler that maps legacy ``*Storage`` class references to
        ``StorageType`` placeholders carrying the storage dtype."""

        def find_class(self, mod_name, name):
            # Names containing 'Storage' are first tried as legacy storage
            # types; if StorageType rejects the name with KeyError, fall
            # back to the regular class lookup.
            if type(name) is str and 'Storage' in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            return super().find_class(mod_name, name)

    def _check_container_source(container_type, source_file, original_source):
        """Emit ``SourceChangeWarning`` if ``container_type``'s current source
        differs from ``original_source`` as saved in the checkpoint.

        When ``container_type.dump_patches`` is set, also try to write a
        reverse unified diff (current -> original) to ``<ClassName>.patch``
        in the working directory. An existing non-empty file is accepted
        only if its contents equal the new patch.
        """
        try:
            current_source = ''.join(get_source_lines_and_file(container_type)[0])
        except Exception:  # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + container_type.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
            return
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    # 'a+' creates the file if missing without truncating it;
                    # seek(0, 2) reports the existing size.
                    with open(file_name, 'a+') as f:
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            raise OSError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except OSError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
            warnings.warn(msg, SourceChangeWarning)

    def legacy_load(f):
        """Load the tar-based legacy format.

        The tar archive holds three members:

        * ``storages``: a count, then for each storage a pickled
          ``(key, location, storage_type)`` header followed by its raw data,
          then a pickled list of storage views;
        * ``tensors``: a count, then for each tensor a pickled
          ``(key, storage_id, original_tensor_type)`` header followed by a
          little-endian size/stride/offset record;
        * ``pickle``: the object itself, whose persistent ids refer to the
          keys above.
        """
        # Shadows the outer ``deserialized_objects``: holds this format's
        # storages, storage views and tensors by key.
        deserialized_objects: Dict[int, Any] = {}

        def persistent_load(saved_id):
            # Tuple ids describe module containers as
            # (container_type, source_file, original_source); any other id
            # is a key into ``deserialized_objects``.
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:

            tar.extract('storages', path=tmpdir)
            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
                num_storages = pickle_module.load(f, **pickle_load_args)
                for i in range(num_storages):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, location, storage_type = args
                    dtype = storage_type._dtype
                    # The storage's raw bytes are read from the current file
                    # position, right after its pickled header.
                    obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
                    obj = restore_location(obj, location)
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[key] = torch.storage.TypedStorage(
                        wrap_storage=obj,
                        dtype=dtype,
                        _internal=True)

                # Each view is (target key, root key, offset, numel); offset
                # and numel count elements of the root storage's dtype.
                storage_views = pickle_module.load(f, **pickle_load_args)
                for target_cdata, root_cdata, offset, numel in storage_views:
                    root = deserialized_objects[root_cdata]
                    element_size = torch._utils._element_size(root.dtype)
                    offset_bytes = offset * element_size
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[target_cdata] = torch.storage.TypedStorage(
                        wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
                        dtype=root.dtype,
                        _internal=True)

            tar.extract('tensors', path=tmpdir)
            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
                num_tensors = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_tensors):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    # Binary tensor record (little-endian): int32 ndim, 4
                    # padding bytes, ndim int64 sizes, ndim int64 strides,
                    # int64 storage offset.
                    ndim, = struct.unpack('<i', f.read(4))
                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                    f.read(4)
                    # NOTE: despite its name, ``numel`` holds the per-dimension
                    # sizes; it is passed as the size argument of set_().
                    numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    storage_offset, = struct.unpack('<q', f.read(8))
                    tensor = torch.tensor([], dtype=storage.dtype).set_(
                        storage._untyped_storage, storage_offset, numel, stride)
                    deserialized_objects[key] = tensor

            # Finally unpickle the object itself; its persistent ids resolve
            # against the storages and tensors loaded above.
            pickle_file = tar.extractfile('pickle')
            unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result

    # Storages of the sequential (non-tar) legacy format, by saved storage key.
    deserialized_objects = {}

    def persistent_load(saved_id):
        """Resolve a persistent id from the sequential legacy format.

        ``('module', container_type, source_file, source)`` returns the
        container type, checking its source when one was saved.
        ``('storage', storage_type, root_key, location, numel, view_metadata)``
        returns a ``TypedStorage`` whose bytes are filled in after unpickling.
        A repeated ``root_key`` reuses the storage already created for it.
        ``view_metadata`` is either ``None`` or
        ``(view_key, offset, view_size)``, counted in elements of the dtype,
        selecting a slice of the root storage.
        """
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == 'storage':
            storage_type, root_key, location, numel, view_metadata = data
            location = _maybe_decode_ascii(location)
            dtype = storage_type.dtype

            nbytes = numel * torch._utils._element_size(dtype)

            if root_key not in deserialized_objects:
                # Allocate now; the actual bytes are read later from the
                # tail of the file.
                obj = cast(Storage, torch.UntypedStorage(nbytes))
                # NOTE(review): flag presumably tells device-restore helpers
                # that the contents are not loaded yet -- confirm.
                obj._torch_load_uninitialized = True
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with TypedStorage
                typed_storage = torch.storage.TypedStorage(
                    wrap_storage=restore_location(obj, location),
                    dtype=dtype,
                    _internal=True)
                deserialized_objects[root_key] = typed_storage
            else:
                typed_storage = deserialized_objects[root_key]
                # An existing storage with no allocated data (pointer 0) is
                # replaced by a fresh empty storage of the requested dtype on
                # the same device. NOTE(review): the replacement is not stored
                # back into deserialized_objects.
                if typed_storage._data_ptr() == 0:
                    typed_storage = torch.storage.TypedStorage(
                        device=typed_storage._untyped_storage.device,
                        dtype=dtype,
                        _internal=True)

            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                offset_bytes = offset * torch._utils._element_size(dtype)
                view_size_bytes = view_size * torch._utils._element_size(dtype)
                if view_key not in deserialized_objects:
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[view_key] = torch.storage.TypedStorage(
                        wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes + view_size_bytes],
                        dtype=dtype,
                        _internal=True)
                res = deserialized_objects[view_key]

            else:
                res = typed_storage
            return res
        else:
            raise RuntimeError(f"Unknown saved id type: {saved_id[0]}")

    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # legacy_load requires that f has fileno()
        # only if offset is zero we can attempt the legacy tar file loader
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if _is_zipfile(f):
                # .zip is used for torch.jit.save and will throw an un-pickling error here
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
            # if not a tarfile, reset file offset and proceed
            f.seek(0)

    if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
        raise RuntimeError(
            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
            f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this "
            "functionality.")

    # Header: magic number, protocol version, then system info.
    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError(f"Invalid protocol version: {protocol_version}")

    # System info (byte order, C type sizes) is read to advance the stream
    # but is not otherwise used here.
    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = UnpicklerWrapper(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    # Raw storage data follows, one storage per key, in this list's order.
    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

    # When reading directly, the byte offset is tracked explicitly and passed
    # to _set_from_file; otherwise None is passed (presumably meaning "read
    # from the current stream position" -- confirm).
    offset = f.tell() if f_should_read_directly else None
    for key in deserialized_storage_keys:
        assert key in deserialized_objects
        typed_storage = deserialized_objects[key]
        typed_storage._untyped_storage._set_from_file(
            f, offset, f_should_read_directly,
            torch._utils._element_size(typed_storage.dtype))
        if offset is not None:
            offset = f.tell()

    # NOTE(review): presumably runs validation of sparse tensors that was
    # deferred during unpickling -- confirm against torch._utils.
    torch._utils._validate_loaded_sparse_tensors()

    return result
 |  | 
 |  | 
 | def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str: | 
 |     # When using encoding='bytes' in Py3, some **internal** keys stored as | 
 |     # strings in Py2 are loaded as bytes. This function decodes them with | 
 |     # ascii encoding, one that Py3 uses by default. | 
 |     # | 
 |     # NOTE: This should only be used on internal keys (e.g., `typename` and | 
 |     #       `location` in `persistent_load` below! | 
 |     if isinstance(bytes_str, bytes): | 
 |         return bytes_str.decode('ascii') | 
 |     return bytes_str | 
 |  | 
 |  | 
 | def _get_restore_location(map_location): | 
 |     if map_location is None: | 
 |         restore_location = default_restore_location | 
 |     elif isinstance(map_location, dict): | 
 |         def restore_location(storage, location): | 
 |             location = map_location.get(location, location) | 
 |             return default_restore_location(storage, location) | 
 |     elif isinstance(map_location, (str, bytes)): | 
 |         def restore_location(storage, location): | 
 |             return default_restore_location(storage, map_location) | 
 |     elif isinstance(map_location, torch.device): | 
 |         def restore_location(storage, location): | 
 |             return default_restore_location(storage, str(map_location)) | 
 |     else: | 
 |         def restore_location(storage, location): | 
 |             result = map_location(storage, location) | 
 |             if result is None: | 
 |                 result = default_restore_location(storage, location) | 
 |             return result | 
 |     return restore_location | 
 |  | 
 |  | 
 | class StorageType: | 
 |     def __init__(self, name): | 
 |         self._dtype = _get_dtype_from_pickle_storage_type(name) | 
 |  | 
 |     @property | 
 |     def dtype(self): | 
 |         return self._dtype | 
 |  | 
 |     def __str__(self): | 
 |         return f'StorageType(dtype={self.dtype})' | 
 |  | 
 |  | 
 | def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args): | 
 |     restore_location = _get_restore_location(map_location) | 
 |  | 
 |     loaded_storages = {} | 
 |  | 
 |     # check if byteswapping is needed | 
 |     byteordername = 'byteorder' | 
 |     byteorderdata = None | 
 |     if zip_file.has_record(byteordername): | 
 |         byteorderdata = zip_file.get_record(byteordername) | 
 |         if byteorderdata not in [b'little', b'big']: | 
 |             raise ValueError('Unknown endianness type: ' + byteorderdata.decode()) | 
 |     elif get_default_load_endianness() == LoadEndianness.LITTLE or \ | 
 |             get_default_load_endianness() is None: | 
 |         byteorderdata = b'little' | 
 |     elif get_default_load_endianness() == LoadEndianness.BIG: | 
 |         byteorderdata = b'big' | 
 |     elif get_default_load_endianness() == LoadEndianness.NATIVE: | 
 |         pass | 
 |     else: | 
 |         raise ValueError('Invalid load endianness type') | 
 |  | 
 |     if not zip_file.has_record(byteordername) and \ | 
 |             get_default_load_endianness() is None and \ | 
 |             sys.byteorder == 'big': | 
 |         # Default behaviour was changed | 
 |         # See https://github.com/pytorch/pytorch/issues/101688 | 
 |         warnings.warn("The default load endianness for checkpoints without a byteorder mark " | 
 |                       "on big endian machines was changed from 'native' to 'little' endian, " | 
 |                       "to avoid this behavior please use " | 
 |                       "torch.serialization.set_default_load_endianness to set " | 
 |                       "the desired default load endianness", | 
 |                       UserWarning) | 
 |  | 
 |     def load_tensor(dtype, numel, key, location): | 
 |         name = f'data/{key}' | 
 |         if overall_storage is not None: | 
 |             storage_offset = zip_file.get_record_offset(name) | 
 |             storage = overall_storage[storage_offset:storage_offset + numel] | 
 |         else: | 
 |             storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage | 
 |         # swap here if byteswapping is needed | 
 |         if byteorderdata is not None: | 
 |             if byteorderdata.decode() != sys.byteorder: | 
 |                 storage.byteswap(dtype) | 
 |  | 
 |         # TODO: Once we decide to break serialization FC, we can | 
 |         # stop wrapping with TypedStorage | 
 |         typed_storage = torch.storage.TypedStorage( | 
 |             wrap_storage=restore_location(storage, location), | 
 |             dtype=dtype, | 
 |             _internal=True) | 
 |  | 
 |         if typed_storage._data_ptr() != 0: | 
 |             loaded_storages[key] = typed_storage | 
 |  | 
 |         return typed_storage | 
 |  | 
 |     def persistent_load(saved_id): | 
 |         assert isinstance(saved_id, tuple) | 
 |         typename = _maybe_decode_ascii(saved_id[0]) | 
 |         data = saved_id[1:] | 
 |  | 
 |         assert typename == 'storage', \ | 
 |             f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" | 
 |         storage_type, key, location, numel = data | 
 |         if storage_type is torch.UntypedStorage: | 
 |             dtype = torch.uint8 | 
 |         else: | 
 |             dtype = storage_type.dtype | 
 |  | 
 |         if key in loaded_storages: | 
 |             typed_storage = loaded_storages[key] | 
 |         else: | 
 |             nbytes = numel * torch._utils._element_size(dtype) | 
 |             typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location)) | 
 |  | 
 |         return typed_storage | 
 |  | 
 |     load_module_mapping: Dict[str, str] = { | 
 |         # See https://github.com/pytorch/pytorch/pull/51633 | 
 |         'torch.tensor': 'torch._tensor' | 
 |     } | 
 |  | 
 |     # Need to subclass Unpickler instead of directly monkey-patching the find_class method | 
 |     # because it's marked readonly in pickle. | 
 |     # The type: ignore is because mypy can't statically determine the type of this class. | 
 |     class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined] | 
 |         # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732 | 
 |         # Lets us override the imports that pickle uses when unpickling an object. | 
 |         # This is useful for maintaining BC if we change a module path that tensor instantiation relies on. | 
 |         def find_class(self, mod_name, name): | 
 |             if type(name) is str and 'Storage' in name: | 
 |                 try: | 
 |                     return StorageType(name) | 
 |                 except KeyError: | 
 |                     pass | 
 |             mod_name = load_module_mapping.get(mod_name, mod_name) | 
 |             return super().find_class(mod_name, name) | 
 |  | 
 |     # Load the data (which may in turn use `persistent_load` to load tensors) | 
 |     data_file = io.BytesIO(zip_file.get_record(pickle_file)) | 
 |  | 
 |     unpickler = UnpicklerWrapper(data_file, **pickle_load_args) | 
 |     unpickler.persistent_load = persistent_load | 
 |     result = unpickler.load() | 
 |  | 
 |     torch._utils._validate_loaded_sparse_tensors() | 
 |     torch._C._log_api_usage_metadata( | 
 |         "torch.load.metadata", {"serialization_id": zip_file.serialization_id()} | 
 |     ) | 
 |     return result | 
 |  | 
 |  | 
 | def _is_torchscript_zip(zip_file): | 
 |     return 'constants.pkl' in zip_file.get_all_records() |