import difflib
import os
import io
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from ._utils import _import_dotted_name
from ._six import string_classes as _string_classes
from torch._utils_internal import get_source_lines_and_file
from torch.types import Storage
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO
import copyreg
import pickle
import pathlib

DEFAULT_PROTOCOL = 2

LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size

MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
PROTOCOL_VERSION = 1001
STORAGE_KEY_SEPARATOR = ','

class SourceChangeWarning(Warning):
    pass


@contextmanager
def mkdtemp():
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        # Clean up the temporary directory even if the body raises
        shutil.rmtree(path)


_package_registry = []


def _is_zipfile(f) -> bool:
    # This is a stricter implementation than zipfile.is_zipfile().
    # zipfile.is_zipfile() is True if the magic number appears anywhere in the
    # binary. Since we expect the files here to be generated by torch.save or
    # torch.jit.save, it's safe to check only the start bytes, which avoids
    # collisions with files that merely embed a zip archive somewhere inside.
    # See bugs.python.org/issue28494.

    # Read the first 4 bytes of the file
    read_bytes = []
    start = f.tell()

    byte = f.read(1)
    # Compare against b"" (not ""): read() on a binary file returns bytes,
    # and comparing bytes to str is never equal, so the old check could not
    # detect EOF correctly.
    while byte != b"":
        read_bytes.append(byte)
        if len(read_bytes) == 4:
            break
        byte = f.read(1)
    f.seek(start)

    local_header_magic_number = [b'P', b'K', b'\x03', b'\x04']
    return read_bytes == local_header_magic_number


def register_package(priority, tagger, deserializer):
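    '''Registers a tagger and deserializer for a storage location.

    `tagger` maps a storage to a location string (or None if it does not
    handle the storage); `deserializer` maps a (storage, location) pair to a
    restored storage (or None). Lower `priority` values are consulted first.

    Example (illustrative sketch for a hypothetical 'xpu' backend; the
    `xpu_tag`, `xpu_deserialize`, and `Storage.xpu()` names are assumptions)::
        >>> def xpu_tag(obj):
        ...     if obj.device.type == 'xpu':
        ...         return 'xpu'
        >>> def xpu_deserialize(obj, location):
        ...     if location == 'xpu':
        ...         return obj.xpu()  # hypothetical Storage.xpu() method
        >>> register_package(30, xpu_tag, xpu_deserialize)
    '''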
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    _package_registry.sort()


def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
    '''
    Check if a module's version satisfies the given requirement.

    Usually, a module's version string is of the form 'x.y.z', which is
    represented as a tuple (x, y, z), but sometimes it is in an unexpected
    format. If the version string does not match the given tuple's format up
    to the length of the tuple, either raise an error or emit a warning,
    depending on `error_if_malformed`.

    Args:
        module: the module to check the version of
        req_version_tuple: tuple (usually of ints) representing the required version
        error_if_malformed: whether to raise an error if the module version string is malformed

    Returns:
        requirement_is_met: bool
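
    Example (illustrative; assumes the installed torch version is >= 1.0)::
        >>> import torch
        >>> torch.serialization.check_module_version_greater_or_equal(torch, (1, 0))
        True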
    '''
    try:
        version_strs = module.__version__.split('.')
        # Cast module version fields to match the types of the required version
        module_version = tuple(
            type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple)
        )
        requirement_is_met = module_version >= req_version_tuple

    except Exception as e:
        message = (
            "'%s' module version string is malformed '%s' and cannot be compared"
            " with tuple %s"
        ) % (
            module.__name__, module.__version__, str(req_version_tuple)
        )
        if error_if_malformed:
            raise RuntimeError(message) from e
        else:
            warnings.warn(message + ', but continuing assuming that requirement is met')
            requirement_is_met = True

    return requirement_is_met


def _cpu_tag(obj):
    if type(obj).__module__ == 'torch':
        return 'cpu'


def _cuda_tag(obj):
    if type(obj).__module__ == 'torch.cuda':
        return 'cuda:' + str(obj.get_device())


def _cpu_deserialize(obj, location):
    if location == 'cpu':
        return obj


def validate_cuda_device(location):
    device = torch.cuda._utils._get_device_index(location, True)

    if not torch.cuda.is_available():
        raise RuntimeError('Attempting to deserialize object on a CUDA '
                           'device but torch.cuda.is_available() is False. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')
    device_count = torch.cuda.device_count()
    if device >= device_count:
        raise RuntimeError('Attempting to deserialize object on CUDA device '
                           f'{device} but torch.cuda.device_count() is {device_count}. Please use '
                           'torch.load with map_location to map your storages '
                           'to an existing device.')
    return device


def _cuda_deserialize(obj, location):
    if location.startswith('cuda'):
        device = validate_cuda_device(location)
        if getattr(obj, "_torch_load_uninitialized", False):
            storage_type = getattr(torch.cuda, type(obj).__name__)
            with torch.cuda.device(device):
                return storage_type(obj.size())
        else:
            return obj.cuda(device)


register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)


def location_tag(storage: Storage):
    for _, tagger, _ in _package_registry:
        location = tagger(storage)
        if location:
            return location
    raise RuntimeError("don't know how to determine data location of "
                       + torch.typename(storage))


def default_restore_location(storage, location):
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError("don't know how to restore data location of "
                       + torch.typename(storage) + " (tagged with "
                       + location + ")")


def normalize_storage_type(storage_type):
    return getattr(torch, storage_type.__name__)


def storage_to_tensor_type(storage):
    storage_type = type(storage)
    module = _import_dotted_name(storage_type.__module__)
    return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))


def _is_path(name_or_buffer):
    return isinstance(name_or_buffer, (str, pathlib.Path))


class _opener(object):
    def __init__(self, file_like):
        self.file_like = file_like

    def __enter__(self):
        return self.file_like

    def __exit__(self, *args):
        pass


class _open_file(_opener):
    def __init__(self, name, mode):
        super(_open_file, self).__init__(open(name, mode))

    def __exit__(self, *args):
        self.file_like.close()


class _open_buffer_reader(_opener):
    def __init__(self, buffer):
        super(_open_buffer_reader, self).__init__(buffer)
        _check_seekable(buffer)


class _open_buffer_writer(_opener):
    def __exit__(self, *args):
        self.file_like.flush()


def _open_file_like(name_or_buffer, mode):
    if _is_path(name_or_buffer):
        return _open_file(name_or_buffer, mode)
    else:
        if 'w' in mode:
            return _open_buffer_writer(name_or_buffer)
        elif 'r' in mode:
            return _open_buffer_reader(name_or_buffer)
        else:
            raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")


class _open_zipfile_reader(_opener):
    def __init__(self, name_or_buffer) -> None:
        super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))


class _open_zipfile_writer_file(_opener):
    def __init__(self, name) -> None:
        super(_open_zipfile_writer_file, self).__init__(torch._C.PyTorchFileWriter(str(name)))

    def __exit__(self, *args) -> None:
        self.file_like.write_end_of_file()


class _open_zipfile_writer_buffer(_opener):
    def __init__(self, buffer) -> None:
        self.buffer = buffer
        super(_open_zipfile_writer_buffer, self).__init__(torch._C.PyTorchFileWriter(buffer))

    def __exit__(self, *args) -> None:
        self.file_like.write_end_of_file()
        self.buffer.flush()


def _open_zipfile_writer(name_or_buffer):
    container: Type[_opener]
    if _is_path(name_or_buffer):
        container = _open_zipfile_writer_file
    else:
        container = _open_zipfile_writer_buffer
    return container(name_or_buffer)


def _is_compressed_file(f) -> bool:
    compress_modules = ['gzip']
    try:
        return f.__module__ in compress_modules
    except AttributeError:
        return False
def _should_read_directly(f):
    """
    Checks if f is a file that should be read directly. It should be read
    directly if it is backed by a real file (has a fileno) and is not a
    compressed file (e.g. gzip).
    """
    if _is_compressed_file(f):
        return False
    try:
        return f.fileno() >= 0
    except io.UnsupportedOperation:
        return False
    except AttributeError:
        return False


def _check_seekable(f) -> bool:

    def raise_err_msg(patterns, e):
        for p in patterns:
            if p in str(e):
                msg = (str(e) + ". You can only torch.load from a file that is seekable."
                                + " Please pre-load the data into a buffer like io.BytesIO and"
                                + " try to load from it instead.")
                raise type(e)(msg)
        raise e

    try:
        f.seek(f.tell())
        return True
    except (io.UnsupportedOperation, AttributeError) as e:
        raise_err_msg(["seek", "tell"], e)
    return False

def _check_dill_version(pickle_module) -> None:
    '''Checks whether dill is being used as the pickle module and, if so,
    whether its version is supported. A ValueError is raised if the dill
    version is lower than 0.3.1.

    Args:
        pickle_module: module used for pickling metadata and objects

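    Example (illustrative)::
        >>> import pickle
        >>> _check_dill_version(pickle)  # no-op: the stdlib pickle module is not dill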
    '''
    if pickle_module.__name__ == 'dill':
        required_dill_version = (0, 3, 1)
        if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False):
            raise ValueError((
                "'torch' supports dill >= %s, but you have dill %s."
                " Please upgrade dill or switch to 'pickle'"
            ) % (
                '.'.join([str(num) for num in required_dill_version]),
                pickle_module.__version__
            ))

def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
         pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
    """Saves an object to a disk file.

    See also: :ref:`saving-loading-tensors`

    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string or
           os.PathLike object containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol

    .. note::
        A common PyTorch convention is to save tensors using the .pt file extension.

    .. note::
        PyTorch preserves storage sharing across serialization. See
        :ref:`preserve-storage-sharing` for more details.

    .. note::
        The 1.6 release of PyTorch switched ``torch.save`` to use a new
        zipfile-based file format. ``torch.load`` still retains the ability to
        load files in the old format. If for any reason you want ``torch.save``
        to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.

    Example:
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
    """
    _check_dill_version(pickle_module)

    with _open_file_like(f, 'wb') as opened_file:
        if _use_new_zipfile_serialization:
            with _open_zipfile_writer(opened_file) as opened_zipfile:
                _save(obj, opened_zipfile, pickle_module, pickle_protocol)
                return
        _legacy_save(obj, opened_file, pickle_module, pickle_protocol)


def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
    import torch.nn as nn
    serialized_container_types = {}
    serialized_storages = {}

    def persistent_id(obj: Any) -> Optional[Tuple]:
        # FIXME: the docs say that persistent_id should only return a string,
        # but torch returns tuples here. This works only in the binary
        # protocol; see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)
                source = ''.join(source_lines)
            except Exception:  # saving the source is optional, so we can ignore any errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)

        elif torch.is_storage(obj):
            view_metadata: Optional[Tuple[str, int, int]]
            obj = cast(Storage, obj)
            storage_type = normalize_storage_type(type(obj))
            # Offset is always 0, but we keep it for backwards compatibility
            # with the old serialization format (which supported storage views)
            offset = 0
            obj_key = str(obj._cdata)
            location = location_tag(obj)
            serialized_storages[obj_key] = obj
            # Storage views are no longer supported, so this comparison is
            # intentionally always False and view_metadata stays None; the
            # slot is kept so the on-disk format remains unchanged.
            is_view = obj._cdata != obj._cdata
            if is_view:
                view_metadata = (str(obj._cdata), offset, obj.size())
            else:
                view_metadata = None

            return ('storage',
                    storage_type,
                    obj_key,
                    location,
                    obj.size(),
                    view_metadata)
        return None

    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    f.flush()
    for key in serialized_storage_keys:
        serialized_storages[key]._write_file(f, _should_read_directly(f), True)


def _save(obj, zip_file, pickle_module, pickle_protocol):
    serialized_storages = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string,
        # but torch returns tuples here. This works only in the binary
        # protocol; see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if torch.is_storage(obj):
            storage_type = normalize_storage_type(type(obj))
            obj_key = str(obj._cdata)
            location = location_tag(obj)
            serialized_storages[obj_key] = obj

            return ('storage',
                    storage_type,
                    obj_key,
                    location,
                    obj.size())
        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    zip_file.write_record('data.pkl', data_value, len(data_value))

    # Write each storage to a record named data/the_storage_key in the zip archive
    for key in sorted(serialized_storages.keys()):
        name = f'data/{key}'
        storage = serialized_storages[key]
        # Given that we copy things around anyway, we may as well use
        # storage.cpu(); this means that to get tensors serialized, you need
        # to implement .cpu() on the underlying Storage
        if storage.device.type != 'cpu':
            storage = storage.cpu()
        # Now that it is on the CPU we can directly copy it into the zip file
        num_bytes = storage.size() * storage.element_size()
        zip_file.write_record(name, storage.data_ptr(), num_bytes)
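

# NOTE [ illustrative ]: the archive written by `_save` holds one pickle
# record plus one record per storage. The exact entry names below are an
# assumption based on the writes above; the writer may prefix them with an
# archive name. A quick way to peek at the layout with the standard library:
#
#     import torch, zipfile
#     torch.save(torch.arange(3), 'tensor.pt')
#     print(zipfile.ZipFile('tensor.pt').namelist())  # ...'data.pkl', 'data/<key>'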


def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
    """Loads an object saved with :func:`torch.save` from a file.

    :func:`torch.load` uses Python's unpickling facilities but treats storages,
    which underlie tensors, specially. They are first deserialized on the
    CPU and are then moved to the device they were saved from. If this fails
    (e.g. because the runtime system doesn't have certain devices), an exception
    is raised. However, storages can be dynamically remapped to an alternative
    set of devices using the :attr:`map_location` argument.

    If :attr:`map_location` is a callable, it will be called once for each serialized
    storage with two arguments: storage and location. The storage argument
    will be the initial deserialization of the storage, residing on the CPU.
    Each serialized storage has a location tag associated with it which
    identifies the device it was saved from, and this tag is the second
    argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
    for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
    :attr:`map_location` should return either ``None`` or a storage. If
    :attr:`map_location` returns a storage, it will be used as the final deserialized
    object, already moved to the right device. Otherwise, :func:`torch.load` will
    fall back to the default behavior, as if :attr:`map_location` wasn't specified.

    If :attr:`map_location` is a :class:`torch.device` object or a string containing
    a device tag, it indicates the location where all tensors should be loaded.

    Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
    appearing in the file (keys), to ones that specify where to put the
    storages (values).

    User extensions can register their own location tags and tagging and
    deserialization methods using :func:`torch.serialization.register_package`.

    Args:
        f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
            or a string or os.PathLike object containing a file name
        map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
            locations
        pickle_module: module used for unpickling metadata and objects (has to
            match the :attr:`pickle_module` used to serialize the file)
        pickle_load_args: (Python 3 only) optional keyword arguments passed over to
            :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
            :attr:`errors=...`.

    .. warning::
        :func:`torch.load()` uses the ``pickle`` module implicitly, which is known to be insecure.
        It is possible to construct malicious pickle data which will execute arbitrary code
        during unpickling. Never load data that could have come from an untrusted
        source, or that could have been tampered with. **Only load data you trust.**

    .. note::
        When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
        will be loaded to GPU by default. You can call ``torch.load(..., map_location='cpu')``
        and then :meth:`load_state_dict` to avoid a GPU RAM surge when loading a model checkpoint.

    .. note::
        By default, we decode byte strings as ``utf-8``.  This is to avoid a common error
        case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
        when loading files saved by Python 2 in Python 3.  If this default
        is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
        these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
        to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
        as byte arrays which can be decoded later with ``byte_array.decode(...)``.

    Example:
        >>> torch.load('tensors.pt')
        # Load all tensors onto the CPU
        >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
        # Load all tensors onto the CPU, using a function
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
        # Load all tensors onto GPU 1
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'})
        # Load tensor from io.BytesIO object
        >>> with open('tensor.pt', 'rb') as f:
        ...     buffer = io.BytesIO(f.read())
        >>> torch.load(buffer)
        # Load a module with 'ascii' encoding for unpickling
        >>> torch.load('module.pt', encoding='ascii')
    """
    _check_dill_version(pickle_module)

    if 'encoding' not in pickle_load_args.keys():
        pickle_load_args['encoding'] = 'utf-8'

    with _open_file_like(f, 'rb') as opened_file:
        if _is_zipfile(opened_file):
            # The zipfile reader is going to advance the current file position.
            # If we want to actually tail call to torch.jit.load, we need to
            # reset back to the original position.
            orig_position = opened_file.tell()
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                if _is_torchscript_zip(opened_zipfile):
                    warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive,"
                                  " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
                                  " silence this warning)", UserWarning)
                    opened_file.seek(orig_position)
                    return torch.jit.load(opened_file)
                return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
        return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)


# Register pickling support for layout instances such as
# torch.sparse_coo, etc.
def _get_layout(name):
    """Get layout extension object from its string representation.
 |     """ | 
 |     cache = _get_layout.cache   # type: ignore[attr-defined] | 
 |     if not cache: | 
 |         for v in torch.__dict__.values(): | 
 |             if isinstance(v, torch.layout): | 
 |                 cache[str(v)] = v | 
 |     return cache[name] | 
 |  | 
# There is not yet a good way to type-annotate function attributes:
# https://github.com/python/mypy/issues/2087
_get_layout.cache = {}   # type: ignore[attr-defined]
copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))


def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    deserialized_objects: Dict[int, Any] = {}

    restore_location = _get_restore_location(map_location)

    def _check_container_source(container_type, source_file, original_source):
        try:
            current_source = ''.join(get_source_lines_and_file(container_type)[0])
        except Exception:  # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + container_type.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
            return
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as f:
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            raise IOError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except IOError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
            warnings.warn(msg, SourceChangeWarning)

    def legacy_load(f):
        deserialized_objects: Dict[int, Any] = {}

        def persistent_load(saved_id):
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:

            tar.extract('storages', path=tmpdir)
            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
                num_storages = pickle_module.load(f, **pickle_load_args)
                for i in range(num_storages):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, location, storage_type = args
                    obj = storage_type._new_with_file(f)
                    obj = restore_location(obj, location)
                    deserialized_objects[key] = obj

                storage_views = pickle_module.load(f, **pickle_load_args)
                for target_cdata, root_cdata, offset, size in storage_views:
                    root = deserialized_objects[root_cdata]
                    deserialized_objects[target_cdata] = root[offset:offset + size]

            tar.extract('tensors', path=tmpdir)
            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
                num_tensors = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_tensors):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    tensor_type = storage_to_tensor_type(storage)
                    ndim, = struct.unpack('<i', f.read(4))
                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                    f.read(4)
                    size = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    storage_offset, = struct.unpack('<q', f.read(8))
                    tensor = tensor_type().set_(storage, storage_offset, size, stride)
                    deserialized_objects[key] = tensor

            pickle_file = tar.extractfile('pickle')
            unpickler = pickle_module.Unpickler(pickle_file, **pickle_load_args)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result

    deserialized_objects = {}

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == 'storage':
            data_type, root_key, location, size, view_metadata = data
            location = _maybe_decode_ascii(location)
            if root_key not in deserialized_objects:
                obj = data_type(size)
                obj._torch_load_uninitialized = True
                deserialized_objects[root_key] = restore_location(obj, location)
            storage = deserialized_objects[root_key]
            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                if view_key not in deserialized_objects:
                    deserialized_objects[view_key] = storage[offset:offset + view_size]
                return deserialized_objects[view_key]
            else:
                return storage
        else:
            raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # legacy_load requires that f has fileno(); only if the offset is zero
        # can we attempt the legacy tar-file loader
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if _is_zipfile(f):
                # .zip is used for torch.jit.save and will throw an unpickling error here
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
            # if not a tarfile, reset file offset and proceed
            f.seek(0)

    if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
        raise RuntimeError(
            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
            f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this "
            "functionality.")

    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError("Invalid protocol version: %s" % protocol_version)

    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = pickle_module.Unpickler(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

    offset = f.tell() if f_should_read_directly else None
    for key in deserialized_storage_keys:
        assert key in deserialized_objects
        deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
        if offset is not None:
            offset = f.tell()

    torch._utils._validate_loaded_sparse_tensors()

    return result


def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
    # When using encoding='bytes' in Py3, some **internal** keys stored as
    # strings in Py2 are loaded as bytes. This function decodes them with
    # ascii encoding, one that Py3 uses by default.
    #
    # NOTE: This should only be used on internal keys (e.g., `typename` and
    #       `location` in `persistent_load` below)!
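    #
    # For example (illustrative): _maybe_decode_ascii(b'cpu') == 'cpu', and a
    # value that is already a str, like 'cpu', is returned unchanged.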
    if isinstance(bytes_str, bytes):
        return bytes_str.decode('ascii')
    return bytes_str


def _get_restore_location(map_location):
    if map_location is None:
        restore_location = default_restore_location
    elif isinstance(map_location, dict):
        def restore_location(storage, location):
            location = map_location.get(location, location)
            return default_restore_location(storage, location)
    elif isinstance(map_location, _string_classes):
        def restore_location(storage, location):
            return default_restore_location(storage, map_location)
    elif isinstance(map_location, torch.device):
        def restore_location(storage, location):
            return default_restore_location(storage, str(map_location))
    else:
        def restore_location(storage, location):
            result = map_location(storage, location)
            if result is None:
                result = default_restore_location(storage, location)
            return result
    return restore_location
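

# NOTE [ illustrative ]: with map_location={'cuda:1': 'cpu'}, a storage tagged
# 'cuda:1' is restored via default_restore_location(storage, 'cpu'), while
# storages with any other tag fall through to their original locations.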


def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickle_load_args):
    restore_location = _get_restore_location(map_location)

    loaded_storages = {}

    def load_tensor(data_type, size, key, location):
        name = f'data/{key}'
        dtype = data_type(0).dtype

        storage = zip_file.get_storage_from_record(name, size, dtype).storage()
        loaded_storages[key] = restore_location(storage, location)

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        assert typename == 'storage', \
            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        data_type, key, location, size = data
        if key not in loaded_storages:
            load_tensor(data_type, size, key, _maybe_decode_ascii(location))
        storage = loaded_storages[key]
        return storage

    # Load the data (which may in turn use `persistent_load` to load tensors)
    data_file = io.BytesIO(zip_file.get_record(pickle_file))
    unpickler = pickle_module.Unpickler(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    torch._utils._validate_loaded_sparse_tensors()

    return result


def _is_torchscript_zip(zip_file):
    return 'constants.pkl' in zip_file.get_all_records()