Sam Gross | e3e786e | 2016-11-03 16:29:14 -0400 | [diff] [blame] | 1 | import difflib |
Zsolt Dollenstein | b004307 | 2021-08-12 10:56:55 -0700 | [diff] [blame] | 2 | import os |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 3 | import io |
Adam Paszke | 686e8d3 | 2016-08-22 22:11:50 -0400 | [diff] [blame] | 4 | import shutil |
| 5 | import struct |
Sam Gross | e3e786e | 2016-11-03 16:29:14 -0400 | [diff] [blame] | 6 | import sys |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 7 | import torch |
Sam Gross | e3e786e | 2016-11-03 16:29:14 -0400 | [diff] [blame] | 8 | import tarfile |
| 9 | import tempfile |
| 10 | import warnings |
Adam Paszke | 686e8d3 | 2016-08-22 22:11:50 -0400 | [diff] [blame] | 11 | from contextlib import closing, contextmanager |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 12 | from ._utils import _import_dotted_name |
| 13 | from ._six import string_classes as _string_classes |
Zhengxu Chen | e62189a | 2021-08-05 14:19:56 -0700 | [diff] [blame] | 14 | from torch._sources import get_source_lines_and_file |
Nikita Shulga | 591fffc | 2020-07-02 07:09:21 -0700 | [diff] [blame] | 15 | from torch.types import Storage |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 16 | from torch.storage import _get_dtype_from_pickle_storage_type |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 17 | from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO |
| 18 | import copyreg |
| 19 | import pickle |
| 20 | import pathlib |
Adam Paszke | 686e8d3 | 2016-08-22 22:11:50 -0400 | [diff] [blame] | 21 | |
# Default pickle protocol for torch.save; protocol 2 keeps saved files
# readable by the widest range of unpicklers.
DEFAULT_PROTOCOL = 2

# Native ('=') sizes of the C integral types on this platform; recorded in
# the legacy format's sys_info so loaders can detect mismatches.
LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size

# Magic constant written first in legacy checkpoints to identify the format.
MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
# Legacy serialization protocol version, written after the magic number.
PROTOCOL_VERSION = 1001
# Separator used when composing storage keys.
STORAGE_KEY_SEPARATOR = ','
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 31 | |
class SourceChangeWarning(Warning):
    """Warning category for mismatches between saved and current container source.

    NOTE(review): only the category is defined here; the code that emits it
    appears to live in the loading path, outside this chunk — confirm.
    """
    pass
| 34 | |
| 35 | |
@contextmanager
def mkdtemp():
    """Context manager yielding a fresh temporary directory.

    The directory is removed on exit.  The removal sits in a ``finally``
    block: without it, an exception raised inside the ``with`` body would
    propagate through the ``yield`` and the temporary directory would be
    leaked on disk.
    """
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        shutil.rmtree(path)
| 41 | |
| 42 | |
# Global list of (priority, tagger, deserializer) triples; register_package()
# appends to it and keeps it sorted so lower priorities are consulted first.
_package_registry = []
| 44 | |
| 45 | |
Nikita Shulga | 591fffc | 2020-07-02 07:09:21 -0700 | [diff] [blame] | 46 | def _is_zipfile(f) -> bool: |
davidriazati | 7a921ba | 2019-08-30 16:43:45 -0700 | [diff] [blame] | 47 | # This is a stricter implementation than zipfile.is_zipfile(). |
| 48 | # zipfile.is_zipfile() is True if the magic number appears anywhere in the |
| 49 | # binary. Since we expect the files here to be generated by torch.save or |
| 50 | # torch.jit.save, it's safe to only check the start bytes and avoid |
davidriazati | 74ce3a0 | 2020-02-05 15:30:21 -0800 | [diff] [blame] | 51 | # collisions and assume the zip has only 1 file. |
| 52 | # See bugs.python.org/issue28494. |
davidriazati | 7a921ba | 2019-08-30 16:43:45 -0700 | [diff] [blame] | 53 | |
| 54 | # Read the first 4 bytes of the file |
| 55 | read_bytes = [] |
| 56 | start = f.tell() |
| 57 | |
davidriazati | 7a921ba | 2019-08-30 16:43:45 -0700 | [diff] [blame] | 58 | byte = f.read(1) |
| 59 | while byte != "": |
| 60 | read_bytes.append(byte) |
| 61 | if len(read_bytes) == 4: |
| 62 | break |
| 63 | byte = f.read(1) |
| 64 | f.seek(start) |
| 65 | |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 66 | local_header_magic_number = [b'P', b'K', b'\x03', b'\x04'] |
davidriazati | 74ce3a0 | 2020-02-05 15:30:21 -0800 | [diff] [blame] | 67 | return read_bytes == local_header_magic_number |
davidriazati | 7a921ba | 2019-08-30 16:43:45 -0700 | [diff] [blame] | 68 | |
| 69 | |
def register_package(priority, tagger, deserializer):
    """Register a (priority, tagger, deserializer) triple in the global
    ``_package_registry``, keeping the registry sorted by ascending priority.

    Sorting uses only the priority as the key: the previous ``sort()`` on the
    raw tuples fell back to comparing the tagger functions whenever two
    packages shared a priority, which raises TypeError in Python 3 (function
    objects are not orderable).  Sorting by priority alone is also stable, so
    equal-priority packages keep their registration order.
    """
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    _package_registry.sort(key=lambda item: item[0])
| 74 | |
| 75 | |
def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
    '''
    Check if a module's version satisfies requirements

    Usually, a module's version string will be like 'x.y.z', which would be represented
    as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version
    string does not match the given tuple's format up to the length of the tuple, then
    error and exit or emit a warning.

    Args:
        module: the module to check the version of
        req_version_tuple: tuple (usually of ints) representing the required version
        error_if_malformed: whether we should exit if module version string is malformed

    Returns:
        requirement_is_met: bool
    '''
    try:
        fields = module.__version__.split('.')
        # Coerce each parsed field to the type of the corresponding required
        # field so that, e.g., ints compare numerically rather than lexically.
        actual_version = tuple(
            type(required)(fields[position])
            for position, required in enumerate(req_version_tuple)
        )
        return actual_version >= req_version_tuple
    except Exception as e:
        message = (
            "'%s' module version string is malformed '%s' and cannot be compared"
            " with tuple %s"
        ) % (
            module.__name__, module.__version__, str(req_version_tuple)
        )
        if error_if_malformed:
            raise RuntimeError(message) from e
        warnings.warn(message + ', but continuing assuming that requirement is met')
        return True
| 115 | |
| 116 | |
Adam Paszke | 0c9670d | 2016-10-04 11:33:00 -0700 | [diff] [blame] | 117 | def _cpu_tag(obj): |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 118 | if type(obj).__module__ == 'torch': |
| 119 | return 'cpu' |
Adam Paszke | 0c9670d | 2016-10-04 11:33:00 -0700 | [diff] [blame] | 120 | |
| 121 | |
| 122 | def _cuda_tag(obj): |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 123 | if type(obj).__module__ == 'torch.cuda': |
| 124 | return 'cuda:' + str(obj.get_device()) |
Adam Paszke | 0c9670d | 2016-10-04 11:33:00 -0700 | [diff] [blame] | 125 | |
| 126 | |
| 127 | def _cpu_deserialize(obj, location): |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 128 | if location == 'cpu': |
Adam Paszke | 0c9670d | 2016-10-04 11:33:00 -0700 | [diff] [blame] | 129 | return obj |
| 130 | |
| 131 | |
def validate_cuda_device(location):
    """Resolve a 'cuda[:N]' location tag to a device index, raising a
    descriptive RuntimeError when CUDA is unavailable or the index is out of
    range for this machine."""
    index = torch.cuda._utils._get_device_index(location, True)

    if not torch.cuda.is_available():
        raise RuntimeError('Attempting to deserialize object on a CUDA '
                           'device but torch.cuda.is_available() is False. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')

    available = torch.cuda.device_count()
    if index >= available:
        raise RuntimeError('Attempting to deserialize object on CUDA device '
                           f'{index} but torch.cuda.device_count() is {available}. Please use '
                           'torch.load with map_location to map your storages '
                           'to an existing device.')
    return index
| 148 | |
| 149 | |
Adam Paszke | 0c9670d | 2016-10-04 11:33:00 -0700 | [diff] [blame] | 150 | def _cuda_deserialize(obj, location): |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 151 | if location.startswith('cuda'): |
Lu Fang | e0f6867 | 2018-12-03 14:07:50 -0800 | [diff] [blame] | 152 | device = validate_cuda_device(location) |
Luca Wehrstedt | 29f4f8f | 2019-02-21 01:24:56 -0800 | [diff] [blame] | 153 | if getattr(obj, "_torch_load_uninitialized", False): |
| 154 | storage_type = getattr(torch.cuda, type(obj).__name__) |
| 155 | with torch.cuda.device(device): |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 156 | return storage_type(obj.nbytes()) |
Luca Wehrstedt | 29f4f8f | 2019-02-21 01:24:56 -0800 | [diff] [blame] | 157 | else: |
| 158 | return obj.cuda(device) |
Adam Paszke | 0c9670d | 2016-10-04 11:33:00 -0700 | [diff] [blame] | 159 | |
| 160 | |
# Built-in handlers.  The registry is kept sorted by ascending priority, so
# the CPU pair (10) is consulted before the CUDA pair (20).
register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)
| 163 | |
| 164 | |
def location_tag(storage: Union[Storage, torch.storage._TypedStorage]):
    """Return the location tag (e.g. 'cpu', 'cuda:0') for ``storage`` from the
    first registered tagger that recognizes it.

    Raises RuntimeError when no tagger claims the storage.
    """
    for entry in _package_registry:
        tag = entry[1](storage)
        if tag:
            return tag
    raise RuntimeError("don't know how to determine data location of "
                       + torch.typename(storage))
Adam Paszke | 0c9670d | 2016-10-04 11:33:00 -0700 | [diff] [blame] | 172 | |
| 173 | |
def default_restore_location(storage, location):
    """Restore ``storage`` via the first registered deserializer that accepts
    the given location tag.

    Raises RuntimeError when every deserializer declines (returns None).
    """
    for entry in _package_registry:
        restored = entry[2](storage, location)
        if restored is not None:
            return restored
    raise RuntimeError("don't know how to restore data location of "
                       + torch.typename(storage) + " (tagged with "
                       + location + ")")
Adam Paszke | 0c9670d | 2016-10-04 11:33:00 -0700 | [diff] [blame] | 182 | |
| 183 | |
def normalize_storage_type(storage_type):
    """Map a storage class to its same-named counterpart in the top-level
    ``torch`` namespace."""
    canonical_name = storage_type.__name__
    return getattr(torch, canonical_name)
| 186 | |
| 187 | |
def storage_to_tensor_type(storage):
    """Return the tensor class matching ``storage``'s class, found by swapping
    'Storage' for 'Tensor' in the class name within the same module."""
    cls = type(storage)
    owning_module = _import_dotted_name(cls.__module__)
    tensor_name = cls.__name__.replace('Storage', 'Tensor')
    return getattr(owning_module, tensor_name)
Adam Paszke | 0c9670d | 2016-10-04 11:33:00 -0700 | [diff] [blame] | 192 | |
| 193 | |
Your Name | bfedace | 2019-11-14 13:35:37 -0800 | [diff] [blame] | 194 | def _is_path(name_or_buffer): |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 195 | return isinstance(name_or_buffer, str) or \ |
| 196 | isinstance(name_or_buffer, pathlib.Path) |
Your Name | fff4f16 | 2019-11-06 18:40:10 -0800 | [diff] [blame] | 197 | |
Your Name | fff4f16 | 2019-11-06 18:40:10 -0800 | [diff] [blame] | 198 | |
Your Name | bfedace | 2019-11-14 13:35:37 -0800 | [diff] [blame] | 199 | class _opener(object): |
| 200 | def __init__(self, file_like): |
| 201 | self.file_like = file_like |
Your Name | fff4f16 | 2019-11-06 18:40:10 -0800 | [diff] [blame] | 202 | |
| 203 | def __enter__(self): |
Your Name | fff4f16 | 2019-11-06 18:40:10 -0800 | [diff] [blame] | 204 | return self.file_like |
| 205 | |
| 206 | def __exit__(self, *args): |
Your Name | bfedace | 2019-11-14 13:35:37 -0800 | [diff] [blame] | 207 | pass |
| 208 | |
| 209 | |
class _open_file(_opener):
    """Opener that owns the file handle it opens and closes it on exit."""

    def __init__(self, name, mode):
        handle = open(name, mode)
        super(_open_file, self).__init__(handle)

    def __exit__(self, *args):
        self.file_like.close()
| 216 | |
| 217 | |
class _open_buffer_reader(_opener):
    """Opener for a caller-provided readable buffer.

    Seekability is validated eagerly so a useless buffer fails fast,
    before any deserialization work starts.
    """

    def __init__(self, buffer):
        _check_seekable(buffer)
        super(_open_buffer_reader, self).__init__(buffer)
Your Name | bfedace | 2019-11-14 13:35:37 -0800 | [diff] [blame] | 222 | |
| 223 | |
class _open_buffer_writer(_opener):
    """Opener for a caller-provided writable buffer.

    The buffer is flushed — but never closed — on exit, since the caller
    still owns it.
    """

    def __exit__(self, *args):
        self.file_like.flush()
| 227 | |
| 228 | |
def _open_file_like(name_or_buffer, mode):
    """Dispatch to the appropriate opener: paths get a real file, buffers get
    a reader or writer wrapper depending on ``mode``.

    Raises RuntimeError when a buffer is given with a mode containing
    neither 'r' nor 'w'.
    """
    if _is_path(name_or_buffer):
        return _open_file(name_or_buffer, mode)
    if 'w' in mode:
        return _open_buffer_writer(name_or_buffer)
    if 'r' in mode:
        return _open_buffer_reader(name_or_buffer)
    raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
Your Name | bfedace | 2019-11-14 13:35:37 -0800 | [diff] [blame] | 239 | |
| 240 | |
class _open_zipfile_reader(_opener):
    """Opener wrapping ``torch._C.PyTorchFileReader`` over a path or buffer."""

    def __init__(self, name_or_buffer) -> None:
        reader = torch._C.PyTorchFileReader(name_or_buffer)
        super(_open_zipfile_reader, self).__init__(reader)
Your Name | bfedace | 2019-11-14 13:35:37 -0800 | [diff] [blame] | 244 | |
| 245 | |
class _open_zipfile_writer_file(_opener):
    """Opener writing a zip archive straight to a named file on disk."""

    def __init__(self, name) -> None:
        writer = torch._C.PyTorchFileWriter(str(name))
        super(_open_zipfile_writer_file, self).__init__(writer)

    def __exit__(self, *args) -> None:
        # Finalize the archive (end-of-file record) before the writer goes away.
        self.file_like.write_end_of_file()
| 252 | |
| 253 | |
class _open_zipfile_writer_buffer(_opener):
    """Opener writing a zip archive into a caller-provided buffer.

    Keeps a reference to the raw buffer so it can be flushed after the
    archive is finalized.
    """

    def __init__(self, buffer) -> None:
        self.buffer = buffer
        writer = torch._C.PyTorchFileWriter(buffer)
        super(_open_zipfile_writer_buffer, self).__init__(writer)

    def __exit__(self, *args) -> None:
        # Finalize the archive first, then make sure it reaches the buffer.
        self.file_like.write_end_of_file()
        self.buffer.flush()
| 262 | |
| 263 | |
def _open_zipfile_writer(name_or_buffer):
    """Choose the path-backed or buffer-backed zipfile writer for the target."""
    container: Type[_opener] = (
        _open_zipfile_writer_file if _is_path(name_or_buffer)
        else _open_zipfile_writer_buffer
    )
    return container(name_or_buffer)
Edward Z. Yang | 57eb8bd | 2017-08-31 08:46:30 -0700 | [diff] [blame] | 271 | |
| 272 | |
Nikita Shulga | 591fffc | 2020-07-02 07:09:21 -0700 | [diff] [blame] | 273 | def _is_compressed_file(f) -> bool: |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 274 | compress_modules = ['gzip'] |
li-roy | bafec16 | 2018-05-31 12:06:38 -0700 | [diff] [blame] | 275 | try: |
| 276 | return f.__module__ in compress_modules |
| 277 | except AttributeError: |
| 278 | return False |
| 279 | |
| 280 | |
def _should_read_directly(f):
    """
    Checks if f is a file that should be read directly. It should be read
    directly if it is backed by a real file (has a fileno) and is not a
    a compressed file (e.g. gzip)
    """
    if _is_compressed_file(f):
        return False
    try:
        descriptor = f.fileno()
    except (io.UnsupportedOperation, AttributeError):
        # Either no real OS-level file backs this object, or it does not
        # expose fileno() at all.
        return False
    return descriptor >= 0
| 295 | |
| 296 | |
Nikita Shulga | 591fffc | 2020-07-02 07:09:21 -0700 | [diff] [blame] | 297 | def _check_seekable(f) -> bool: |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 298 | |
Wei Yang | c3e4b3c | 2018-06-12 12:57:28 -0700 | [diff] [blame] | 299 | def raise_err_msg(patterns, e): |
| 300 | for p in patterns: |
| 301 | if p in str(e): |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 302 | msg = (str(e) + ". You can only torch.load from a file that is seekable." |
| 303 | + " Please pre-load the data into a buffer like io.BytesIO and" |
| 304 | + " try to load from it instead.") |
Wei Yang | c3e4b3c | 2018-06-12 12:57:28 -0700 | [diff] [blame] | 305 | raise type(e)(msg) |
| 306 | raise e |
| 307 | |
| 308 | try: |
| 309 | f.seek(f.tell()) |
| 310 | return True |
| 311 | except (io.UnsupportedOperation, AttributeError) as e: |
| 312 | raise_err_msg(["seek", "tell"], e) |
Nikita Shulga | 591fffc | 2020-07-02 07:09:21 -0700 | [diff] [blame] | 313 | return False |
Wei Yang | c3e4b3c | 2018-06-12 12:57:28 -0700 | [diff] [blame] | 314 | |
Nikita Shulga | 591fffc | 2020-07-02 07:09:21 -0700 | [diff] [blame] | 315 | def _check_dill_version(pickle_module) -> None: |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 316 | '''Checks if using dill as the pickle module, and if so, checks if it is the correct version. |
Kurt Mohler | 3694749 | 2019-12-18 08:03:39 -0800 | [diff] [blame] | 317 | If dill version is lower than 0.3.1, a ValueError is raised. |
| 318 | |
| 319 | Args: |
| 320 | pickle_module: module used for pickling metadata and objects |
| 321 | |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 322 | ''' |
| 323 | if pickle_module.__name__ == 'dill': |
Kurt Mohler | 3694749 | 2019-12-18 08:03:39 -0800 | [diff] [blame] | 324 | required_dill_version = (0, 3, 1) |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 325 | if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False): |
| 326 | raise ValueError(( |
| 327 | "'torch' supports dill >= %s, but you have dill %s." |
| 328 | " Please upgrade dill or switch to 'pickle'" |
| 329 | ) % ( |
| 330 | '.'.join([str(num) for num in required_dill_version]), |
| 331 | pickle_module.__version__ |
| 332 | )) |
Wei Yang | c3e4b3c | 2018-06-12 12:57:28 -0700 | [diff] [blame] | 333 | |
def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
         pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path').

    """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

    Saves an object to a disk file.

    See also: :ref:`saving-loading-tensors`

    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string or
           os.PathLike object containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol

    .. note::
        A common PyTorch convention is to save tensors using .pt file extension.

    .. note::
        PyTorch preserves storage sharing across serialization. See
        :ref:`preserve-storage-sharing` for more details.

    .. note::
        The 1.6 release of PyTorch switched ``torch.save`` to use a new
        zipfile-based file format. ``torch.load`` still retains the ability to
        load files in the old format. If for any reason you want ``torch.save``
        to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.

    Example:
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
    """
    _check_dill_version(pickle_module)

    with _open_file_like(f, 'wb') as opened_file:
        # Legacy (pre-1.6, non-zipfile) path is the opt-out branch.
        if not _use_new_zipfile_serialization:
            _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
            return
        with _open_zipfile_writer(opened_file) as opened_zipfile:
            _save(obj, opened_zipfile, pickle_module, pickle_protocol)
Adam Paszke | e867baa | 2016-11-01 13:22:01 +0100 | [diff] [blame] | 383 | |
| 384 | |
Nikita Shulga | 591fffc | 2020-07-02 07:09:21 -0700 | [diff] [blame] | 385 | def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None: |
Sam Gross | e3e786e | 2016-11-03 16:29:14 -0400 | [diff] [blame] | 386 | import torch.nn as nn |
Sam Gross | e3e786e | 2016-11-03 16:29:14 -0400 | [diff] [blame] | 387 | serialized_container_types = {} |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 388 | serialized_storages = {} |
Adam Paszke | 686e8d3 | 2016-08-22 22:11:50 -0400 | [diff] [blame] | 389 | |
Kurt Mohler | bc3d380 | 2021-11-16 08:41:14 -0800 | [diff] [blame] | 390 | # Since loading storages that view the same data with different dtypes is |
| 391 | # not supported, we need to keep track of the dtype associated with each |
| 392 | # storage data_ptr and throw an error if the dtype is ever different. |
| 393 | # TODO: This feature could be added in the future |
| 394 | storage_dtypes: Dict[int, torch.dtype] = {} |
| 395 | |
Nikita Shulga | 591fffc | 2020-07-02 07:09:21 -0700 | [diff] [blame] | 396 | def persistent_id(obj: Any) -> Optional[Tuple]: |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 397 | # FIXME: the docs say that persistent_id should only return a string |
| 398 | # but torch store returns tuples. This works only in the binary protocol |
| 399 | # see |
| 400 | # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects |
| 401 | # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 |
Adam Paszke | d6fa3b3 | 2017-01-16 21:04:05 +0100 | [diff] [blame] | 402 | if isinstance(obj, type) and issubclass(obj, nn.Module): |
Sam Gross | e3e786e | 2016-11-03 16:29:14 -0400 | [diff] [blame] | 403 | if obj in serialized_container_types: |
| 404 | return None |
| 405 | serialized_container_types[obj] = True |
Adam Paszke | 2bd7a3c | 2016-12-01 20:35:35 +0100 | [diff] [blame] | 406 | source_file = source = None |
| 407 | try: |
Dmytro Dzhulgakov | df338f8 | 2019-09-14 21:26:23 -0700 | [diff] [blame] | 408 | source_lines, _, source_file = get_source_lines_and_file(obj) |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 409 | source = ''.join(source_lines) |
Sam Gross | 8e58135 | 2017-10-23 23:03:37 -0400 | [diff] [blame] | 410 | except Exception: # saving the source is optional, so we can ignore any errors |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 411 | warnings.warn("Couldn't retrieve source code for container of " |
| 412 | "type " + obj.__name__ + ". It won't be checked " |
| 413 | "for correctness upon loading.") |
| 414 | return ('module', obj, source_file, source) |
Geeta Chauhan | 9e314f5 | 2019-11-04 23:15:30 -0800 | [diff] [blame] | 415 | |
Kurt Mohler | 8e7fe87 | 2022-02-15 15:43:57 -0800 | [diff] [blame] | 416 | if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj): |
| 417 | if isinstance(obj, torch.storage._TypedStorage): |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 418 | # TODO: Once we decide to break serialization FC, this case |
| 419 | # can be deleted |
| 420 | storage = obj._storage |
Kurt Mohler | bc3d380 | 2021-11-16 08:41:14 -0800 | [diff] [blame] | 421 | storage_dtype = obj.dtype |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 422 | storage_type_str = obj.pickle_storage_type() |
| 423 | storage_type = getattr(torch, storage_type_str) |
| 424 | dtype = obj.dtype |
| 425 | storage_numel = obj.size() |
| 426 | |
| 427 | else: |
| 428 | storage = obj |
Kurt Mohler | bc3d380 | 2021-11-16 08:41:14 -0800 | [diff] [blame] | 429 | storage_dtype = storage.dtype |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 430 | storage_type = normalize_storage_type(type(obj)) |
| 431 | dtype = torch.uint8 |
| 432 | storage_numel = cast(Storage, storage).nbytes() |
| 433 | |
Kurt Mohler | b69155f | 2021-11-24 09:50:11 -0800 | [diff] [blame] | 434 | # If storage is allocated, ensure that any other saved storages |
| 435 | # pointing to the same data all have the same dtype. If storage is |
| 436 | # not allocated, don't perform this check |
| 437 | if storage.data_ptr() != 0: |
| 438 | if storage.data_ptr() in storage_dtypes: |
| 439 | if storage_dtype != storage_dtypes[storage.data_ptr()]: |
| 440 | raise RuntimeError( |
| 441 | 'Cannot save multiple tensors or storages that ' |
| 442 | 'view the same data as different types') |
| 443 | else: |
| 444 | storage_dtypes[storage.data_ptr()] = storage_dtype |
Kurt Mohler | bc3d380 | 2021-11-16 08:41:14 -0800 | [diff] [blame] | 445 | |
Nikita Shulga | 591fffc | 2020-07-02 07:09:21 -0700 | [diff] [blame] | 446 | view_metadata: Optional[Tuple[str, int, int]] |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 447 | storage = cast(Storage, storage) |
| 448 | |
Edward Yang | 976f925 | 2018-07-16 15:17:57 -0700 | [diff] [blame] | 449 | # Offset is always 0, but we keep it for backwards compatibility |
| 450 | # with the old serialization format (which supported storage views) |
| 451 | offset = 0 |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 452 | storage_key = str(storage._cdata) |
| 453 | location = location_tag(storage) |
| 454 | |
| 455 | # TODO: There's an issue here with FC. It might be impossible to |
| 456 | # solve, but it's worth noting. Imagine we save a list `[storage, |
| 457 | # tensor]`, where `tensor.storage()` is the same as `storage`, and |
| 458 | # `tensor.element_size() > 1`. Let's say that `tensor.dtype == |
| 459 | # torch.float`. The storage will be serialized with element size |
| 460 | # of 1, since we're choosing to serialize the first occurance of |
| 461 | # a duplicate storage. Since this legacy serialization format saves |
| 462 | # the numel of the storage, rather than nbytes directly, we'll be |
| 463 | # effectively saving nbytes in this case. We'll be able to load it |
| 464 | # and the tensor back up with no problems in _this_ and future |
| 465 | # versions of pytorch, but in older versions, here's the problem: |
Kurt Mohler | 8e7fe87 | 2022-02-15 15:43:57 -0800 | [diff] [blame] | 466 | # the storage will be loaded up as a _UntypedStorage, and then the |
| 467 | # FloatTensor will loaded and the _UntypedStorage will be assigned to |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 468 | # it. Since the storage dtype does not match the tensor dtype, this |
| 469 | # will cause an error. If we reverse the list, like `[tensor, |
| 470 | # storage]`, then we will save the `tensor.storage()` as a faked |
| 471 | # `FloatStorage`, and the saved size will be the correct |
| 472 | # dtype-specific numel count that old versions expect. `tensor` |
| 473 | # will be able to load up properly in old versions, pointing to |
| 474 | # a FloatStorage. However, `storage` is still being translated to |
Kurt Mohler | 8e7fe87 | 2022-02-15 15:43:57 -0800 | [diff] [blame] | 475 | # a _UntypedStorage, and it will try to resolve to the same |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 476 | # FloatStorage that `tensor` contains. This will also cause an |
| 477 | # error. It doesn't seem like there's any way around this. |
| 478 | # Probably, we just cannot maintain FC for the legacy format if the |
| 479 | # saved list contains both a tensor and a storage that point to the |
| 480 | # same data. We should still be able to maintain FC for lists of |
| 481 | # just tensors, as long as all views share the same dtype as the |
| 482 | # tensor they are viewing. |
| 483 | |
| 484 | if storage_key not in serialized_storages: |
| 485 | serialized_storages[storage_key] = (storage, dtype) |
| 486 | is_view = storage._cdata != storage._cdata |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 487 | if is_view: |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 488 | view_metadata = (str(storage._cdata), offset, storage.nbytes()) |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 489 | else: |
| 490 | view_metadata = None |
| 491 | |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 492 | res = ('storage', |
| 493 | storage_type, |
| 494 | storage_key, |
| 495 | location, |
| 496 | storage_numel, |
| 497 | view_metadata) |
| 498 | return res |
Adam Paszke | 686e8d3 | 2016-08-22 22:11:50 -0400 | [diff] [blame] | 499 | return None |
| 500 | |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 501 | sys_info = dict( |
| 502 | protocol_version=PROTOCOL_VERSION, |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 503 | little_endian=sys.byteorder == 'little', |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 504 | type_sizes=dict( |
| 505 | short=SHORT_SIZE, |
| 506 | int=INT_SIZE, |
| 507 | long=LONG_SIZE, |
| 508 | ), |
| 509 | ) |
Adam Paszke | 686e8d3 | 2016-08-22 22:11:50 -0400 | [diff] [blame] | 510 | |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 511 | pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol) |
| 512 | pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol) |
| 513 | pickle_module.dump(sys_info, f, protocol=pickle_protocol) |
| 514 | pickler = pickle_module.Pickler(f, protocol=pickle_protocol) |
| 515 | pickler.persistent_id = persistent_id |
| 516 | pickler.dump(obj) |
Adam Paszke | 686e8d3 | 2016-08-22 22:11:50 -0400 | [diff] [blame] | 517 | |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 518 | serialized_storage_keys = sorted(serialized_storages.keys()) |
| 519 | pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol) |
| 520 | f.flush() |
| 521 | for key in serialized_storage_keys: |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 522 | storage, dtype = serialized_storages[key] |
| 523 | storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype)) |
Adam Paszke | 686e8d3 | 2016-08-22 22:11:50 -0400 | [diff] [blame] | 524 | |
| 525 | |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 526 | def _save(obj, zip_file, pickle_module, pickle_protocol): |
| 527 | serialized_storages = {} |
Francesco Casalegno | fea3824 | 2021-05-10 11:48:36 -0700 | [diff] [blame] | 528 | id_map: Dict[int, str] = {} |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 529 | |
Kurt Mohler | bc3d380 | 2021-11-16 08:41:14 -0800 | [diff] [blame] | 530 | # Since loading storages that view the same data with different dtypes is |
| 531 | # not supported, we need to keep track of the dtype associated with each |
| 532 | # storage data_ptr and throw an error if the dtype is ever different. |
| 533 | # TODO: This feature could be added in the future |
| 534 | storage_dtypes: Dict[int, torch.dtype] = {} |
| 535 | |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 536 | def persistent_id(obj): |
| 537 | # FIXME: the docs say that persistent_id should only return a string |
| 538 | # but torch store returns tuples. This works only in the binary protocol |
| 539 | # see |
| 540 | # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects |
| 541 | # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 |
Kurt Mohler | 8e7fe87 | 2022-02-15 15:43:57 -0800 | [diff] [blame] | 542 | if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj): |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 543 | |
Kurt Mohler | 8e7fe87 | 2022-02-15 15:43:57 -0800 | [diff] [blame] | 544 | if isinstance(obj, torch.storage._TypedStorage): |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 545 | # TODO: Once we decide to break serialization FC, this case |
| 546 | # can be deleted |
| 547 | storage = obj._storage |
Kurt Mohler | bc3d380 | 2021-11-16 08:41:14 -0800 | [diff] [blame] | 548 | storage_dtype = obj.dtype |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 549 | storage_type_str = obj.pickle_storage_type() |
| 550 | storage_type = getattr(torch, storage_type_str) |
| 551 | storage_numel = obj.size() |
| 552 | |
| 553 | else: |
| 554 | storage = obj |
Kurt Mohler | bc3d380 | 2021-11-16 08:41:14 -0800 | [diff] [blame] | 555 | storage_dtype = storage.dtype |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 556 | storage_type = normalize_storage_type(type(obj)) |
| 557 | storage_numel = storage.nbytes() |
| 558 | |
| 559 | storage = cast(Storage, storage) |
Kurt Mohler | bc3d380 | 2021-11-16 08:41:14 -0800 | [diff] [blame] | 560 | |
Kurt Mohler | b69155f | 2021-11-24 09:50:11 -0800 | [diff] [blame] | 561 | # If storage is allocated, ensure that any other saved storages |
| 562 | # pointing to the same data all have the same dtype. If storage is |
| 563 | # not allocated, don't perform this check |
| 564 | if storage.data_ptr() != 0: |
| 565 | if storage.data_ptr() in storage_dtypes: |
| 566 | if storage_dtype != storage_dtypes[storage.data_ptr()]: |
| 567 | raise RuntimeError( |
| 568 | 'Cannot save multiple tensors or storages that ' |
| 569 | 'view the same data as different types') |
| 570 | else: |
| 571 | storage_dtypes[storage.data_ptr()] = storage_dtype |
Kurt Mohler | bc3d380 | 2021-11-16 08:41:14 -0800 | [diff] [blame] | 572 | |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 573 | storage_key = id_map.setdefault(storage._cdata, str(len(id_map))) |
| 574 | location = location_tag(storage) |
| 575 | serialized_storages[storage_key] = storage |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 576 | |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 577 | return ('storage', |
| 578 | storage_type, |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 579 | storage_key, |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 580 | location, |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 581 | storage_numel) |
| 582 | |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 583 | return None |
| 584 | |
| 585 | # Write the pickle data for `obj` |
| 586 | data_buf = io.BytesIO() |
| 587 | pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol) |
| 588 | pickler.persistent_id = persistent_id |
| 589 | pickler.dump(obj) |
| 590 | data_value = data_buf.getvalue() |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 591 | zip_file.write_record('data.pkl', data_value, len(data_value)) |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 592 | |
| 593 | # Write each tensor to a file named tensor/the_tensor_key in the zip archive |
| 594 | for key in sorted(serialized_storages.keys()): |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 595 | name = f'data/{key}' |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 596 | storage = serialized_storages[key] |
Thomas Viehmann | 7b7f251 | 2020-10-13 12:48:22 -0700 | [diff] [blame] | 597 | # given that we copy things around anyway, we might use storage.cpu() |
| 598 | # this means to that to get tensors serialized, you need to implement |
| 599 | # .cpu() on the underlying Storage |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 600 | if storage.device.type != 'cpu': |
Thomas Viehmann | 7b7f251 | 2020-10-13 12:48:22 -0700 | [diff] [blame] | 601 | storage = storage.cpu() |
| 602 | # Now that it is on the CPU we can directly copy it into the zip file |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 603 | num_bytes = storage.nbytes() |
Thomas Viehmann | 7b7f251 | 2020-10-13 12:48:22 -0700 | [diff] [blame] | 604 | zip_file.write_record(name, storage.data_ptr(), num_bytes) |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 605 | |
| 606 | |
SsnL | 54d5c53 | 2018-12-10 08:05:06 -0800 | [diff] [blame] | 607 | def load(f, map_location=None, pickle_module=pickle, **pickle_load_args): |
Yukio Siraichi | 9d54475 | 2021-04-27 10:56:41 -0700 | [diff] [blame] | 608 | # Reference: https://github.com/pytorch/pytorch/issues/54354 |
| 609 | # The first line of this docstring overrides the one Sphinx generates for the |
| 610 | # documentation. We need it so that Sphinx doesn't leak `pickle`s path from |
| 611 | # the build environment (e.g. `<module 'pickle' from '/leaked/path'). |
| 612 | |
| 613 | """load(f, map_location=None, pickle_module=pickle, **pickle_load_args) |
| 614 | |
| 615 | Loads an object saved with :func:`torch.save` from a file. |
Adam Paszke | 0c9670d | 2016-10-04 11:33:00 -0700 | [diff] [blame] | 616 | |
Tongzhou Wang | 51ee048 | 2019-06-13 12:08:52 -0700 | [diff] [blame] | 617 | :func:`torch.load` uses Python's unpickling facilities but treats storages, |
greaber | 490d5c2 | 2017-10-14 19:54:53 +0300 | [diff] [blame] | 618 | which underlie tensors, specially. They are first deserialized on the |
| 619 | CPU and are then moved to the device they were saved from. If this fails |
| 620 | (e.g. because the run time system doesn't have certain devices), an exception |
| 621 | is raised. However, storages can be dynamically remapped to an alternative |
Tongzhou Wang | 51ee048 | 2019-06-13 12:08:52 -0700 | [diff] [blame] | 622 | set of devices using the :attr:`map_location` argument. |
Adam Paszke | 0c9670d | 2016-10-04 11:33:00 -0700 | [diff] [blame] | 623 | |
Tongzhou Wang | 51ee048 | 2019-06-13 12:08:52 -0700 | [diff] [blame] | 624 | If :attr:`map_location` is a callable, it will be called once for each serialized |
greaber | 490d5c2 | 2017-10-14 19:54:53 +0300 | [diff] [blame] | 625 | storage with two arguments: storage and location. The storage argument |
| 626 | will be the initial deserialization of the storage, residing on the CPU. |
| 627 | Each serialized storage has a location tag associated with it which |
| 628 | identifies the device it was saved from, and this tag is the second |
Tongzhou Wang | 51ee048 | 2019-06-13 12:08:52 -0700 | [diff] [blame] | 629 | argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'`` |
| 630 | for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors. |
| 631 | :attr:`map_location` should return either ``None`` or a storage. If |
| 632 | :attr:`map_location` returns a storage, it will be used as the final deserialized |
| 633 | object, already moved to the right device. Otherwise, :func:`torch.load` will |
| 634 | fall back to the default behavior, as if :attr:`map_location` wasn't specified. |
greaber | 490d5c2 | 2017-10-14 19:54:53 +0300 | [diff] [blame] | 635 | |
Yash | 293fa5f | 2020-02-21 09:27:19 -0800 | [diff] [blame] | 636 | If :attr:`map_location` is a :class:`torch.device` object or a string containing |
Tongzhou Wang | 51ee048 | 2019-06-13 12:08:52 -0700 | [diff] [blame] | 637 | a device tag, it indicates the location where all tensors should be loaded. |
Adam Paszke | 8307f21 | 2017-12-15 17:50:20 -0500 | [diff] [blame] | 638 | |
Tongzhou Wang | 51ee048 | 2019-06-13 12:08:52 -0700 | [diff] [blame] | 639 | Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags |
greaber | 490d5c2 | 2017-10-14 19:54:53 +0300 | [diff] [blame] | 640 | appearing in the file (keys), to ones that specify where to put the |
| 641 | storages (values). |
| 642 | |
| 643 | User extensions can register their own location tags and tagging and |
Tongzhou Wang | 51ee048 | 2019-06-13 12:08:52 -0700 | [diff] [blame] | 644 | deserialization methods using :func:`torch.serialization.register_package`. |
Adam Paszke | 0c9670d | 2016-10-04 11:33:00 -0700 | [diff] [blame] | 645 | |
| 646 | Args: |
Zain Patel | bbeee48 | 2020-12-14 19:14:02 -0800 | [diff] [blame] | 647 | f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`), |
James Reed | 3ecae99 | 2020-06-30 10:05:57 -0700 | [diff] [blame] | 648 | or a string or os.PathLike object containing a file name |
Tongzhou Wang | 51ee048 | 2019-06-13 12:08:52 -0700 | [diff] [blame] | 649 | map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage |
Leonid Vlasenkov | 46a868d | 2017-07-10 17:24:54 +0300 | [diff] [blame] | 650 | locations |
| 651 | pickle_module: module used for unpickling metadata and objects (has to |
Tongzhou Wang | 51ee048 | 2019-06-13 12:08:52 -0700 | [diff] [blame] | 652 | match the :attr:`pickle_module` used to serialize file) |
nmilosev | 5fc5248 | 2019-09-25 14:56:44 -0700 | [diff] [blame] | 653 | pickle_load_args: (Python 3 only) optional keyword arguments passed over to |
Tongzhou Wang | 51ee048 | 2019-06-13 12:08:52 -0700 | [diff] [blame] | 654 | :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g., |
nmilosev | 5fc5248 | 2019-09-25 14:56:44 -0700 | [diff] [blame] | 655 | :attr:`errors=...`. |
Sam Gross | c4d1318 | 2017-03-15 16:54:19 -0400 | [diff] [blame] | 656 | |
Edgar Andrés Margffoy Tuay | 90a259e | 2020-01-25 22:10:49 -0800 | [diff] [blame] | 657 | .. warning:: |
davidriazati | 74ce3a0 | 2020-02-05 15:30:21 -0800 | [diff] [blame] | 658 | :func:`torch.load()` uses ``pickle`` module implicitly, which is known to be insecure. |
Edgar Andrés Margffoy Tuay | 90a259e | 2020-01-25 22:10:49 -0800 | [diff] [blame] | 659 | It is possible to construct malicious pickle data which will execute arbitrary code |
| 660 | during unpickling. Never load data that could have come from an untrusted |
| 661 | source, or that could have been tampered with. **Only load data you trust**. |
| 662 | |
Ailing Zhang | d793473 | 2018-06-29 16:42:46 -0700 | [diff] [blame] | 663 | .. note:: |
Tongzhou Wang | 51ee048 | 2019-06-13 12:08:52 -0700 | [diff] [blame] | 664 | When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors |
| 665 | will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')`` |
Ailing Zhang | d793473 | 2018-06-29 16:42:46 -0700 | [diff] [blame] | 666 | and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint. |
| 667 | |
SsnL | 54d5c53 | 2018-12-10 08:05:06 -0800 | [diff] [blame] | 668 | .. note:: |
nmilosev | 5fc5248 | 2019-09-25 14:56:44 -0700 | [diff] [blame] | 669 | By default, we decode byte strings as ``utf-8``. This is to avoid a common error |
| 670 | case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...`` |
| 671 | when loading files saved by Python 2 in Python 3. If this default |
| 672 | is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how |
Tongzhou Wang | 51ee048 | 2019-06-13 12:08:52 -0700 | [diff] [blame] | 673 | these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them |
| 674 | to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them |
SsnL | 54d5c53 | 2018-12-10 08:05:06 -0800 | [diff] [blame] | 675 | as byte arrays which can be decoded later with ``byte_array.decode(...)``. |
| 676 | |
Sam Gross | c4d1318 | 2017-03-15 16:54:19 -0400 | [diff] [blame] | 677 | Example: |
| 678 | >>> torch.load('tensors.pt') |
| 679 | # Load all tensors onto the CPU |
Ethan Steinberg | 9fa1dff | 2018-05-10 12:50:00 -0700 | [diff] [blame] | 680 | >>> torch.load('tensors.pt', map_location=torch.device('cpu')) |
Adam Paszke | 8307f21 | 2017-12-15 17:50:20 -0500 | [diff] [blame] | 681 | # Load all tensors onto the CPU, using a function |
Sam Gross | c4d1318 | 2017-03-15 16:54:19 -0400 | [diff] [blame] | 682 | >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage) |
greaber | 490d5c2 | 2017-10-14 19:54:53 +0300 | [diff] [blame] | 683 | # Load all tensors onto GPU 1 |
| 684 | >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1)) |
Sam Gross | c4d1318 | 2017-03-15 16:54:19 -0400 | [diff] [blame] | 685 | # Map tensors from GPU 1 to GPU 0 |
| 686 | >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'}) |
Richard Zou | 8ba8713 | 2018-03-08 22:18:55 -0500 | [diff] [blame] | 687 | # Load tensor from io.BytesIO object |
Thomas Viehmann | 6a6983e | 2019-01-29 11:19:51 -0800 | [diff] [blame] | 688 | >>> with open('tensor.pt', 'rb') as f: |
Sam Estep | c147aa3 | 2021-01-20 15:50:50 -0800 | [diff] [blame] | 689 | ... buffer = io.BytesIO(f.read()) |
Richard Zou | 8ba8713 | 2018-03-08 22:18:55 -0500 | [diff] [blame] | 690 | >>> torch.load(buffer) |
nmilosev | 5fc5248 | 2019-09-25 14:56:44 -0700 | [diff] [blame] | 691 | # Load a module with 'ascii' encoding for unpickling |
| 692 | >>> torch.load('module.pt', encoding='ascii') |
Adam Paszke | 0c9670d | 2016-10-04 11:33:00 -0700 | [diff] [blame] | 693 | """ |
Kurt Mohler | 3694749 | 2019-12-18 08:03:39 -0800 | [diff] [blame] | 694 | _check_dill_version(pickle_module) |
| 695 | |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 696 | if 'encoding' not in pickle_load_args.keys(): |
| 697 | pickle_load_args['encoding'] = 'utf-8' |
Your Name | fff4f16 | 2019-11-06 18:40:10 -0800 | [diff] [blame] | 698 | |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 699 | with _open_file_like(f, 'rb') as opened_file: |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 700 | if _is_zipfile(opened_file): |
James Reed | 9c82b57 | 2020-07-06 08:58:28 -0700 | [diff] [blame] | 701 | # The zipfile reader is going to advance the current file position. |
| 702 | # If we want to actually tail call to torch.jit.load, we need to |
| 703 | # reset back to the original position. |
| 704 | orig_position = opened_file.tell() |
James Reed | 3ecae99 | 2020-06-30 10:05:57 -0700 | [diff] [blame] | 705 | with _open_zipfile_reader(opened_file) as opened_zipfile: |
David Riazati | 8c6f0c0 | 2019-11-22 12:28:49 -0800 | [diff] [blame] | 706 | if _is_torchscript_zip(opened_zipfile): |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 707 | warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive" |
| 708 | " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to" |
| 709 | " silence this warning)", UserWarning) |
James Reed | 9c82b57 | 2020-07-06 08:58:28 -0700 | [diff] [blame] | 710 | opened_file.seek(orig_position) |
| 711 | return torch.jit.load(opened_file) |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 712 | return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args) |
| 713 | return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args) |
Adam Paszke | e867baa | 2016-11-01 13:22:01 +0100 | [diff] [blame] | 714 | |
| 715 | |
Pearu Peterson | b7fb2b8 | 2019-10-04 08:07:44 -0700 | [diff] [blame] | 716 | # Register pickling support for layout instances such as |
| 717 | # torch.sparse_coo, etc |
| 718 | def _get_layout(name): |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 719 | """Get layout extension object from its string representation. |
| 720 | """ |
| 721 | cache = _get_layout.cache # type: ignore[attr-defined] |
Pearu Peterson | b7fb2b8 | 2019-10-04 08:07:44 -0700 | [diff] [blame] | 722 | if not cache: |
| 723 | for v in torch.__dict__.values(): |
| 724 | if isinstance(v, torch.layout): |
| 725 | cache[str(v)] = v |
| 726 | return cache[name] |
| 727 | |
Nikita Shulga | 591fffc | 2020-07-02 07:09:21 -0700 | [diff] [blame] | 728 | # There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087 |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 729 | _get_layout.cache = {} # type: ignore[attr-defined] |
Pearu Peterson | b7fb2b8 | 2019-10-04 08:07:44 -0700 | [diff] [blame] | 730 | copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),))) |
| 731 | |
| 732 | |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 733 | def _legacy_load(f, map_location, pickle_module, **pickle_load_args): |
Nikita Shulga | 591fffc | 2020-07-02 07:09:21 -0700 | [diff] [blame] | 734 | deserialized_objects: Dict[int, Any] = {} |
Adam Paszke | 686e8d3 | 2016-08-22 22:11:50 -0400 | [diff] [blame] | 735 | |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 736 | restore_location = _get_restore_location(map_location) |
Adam Paszke | 0c9670d | 2016-10-04 11:33:00 -0700 | [diff] [blame] | 737 | |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 738 | class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] |
| 739 | |
| 740 | def find_class(self, mod_name, name): |
| 741 | if type(name) is str and 'Storage' in name: |
| 742 | try: |
| 743 | return StorageType(name) |
| 744 | except KeyError: |
| 745 | pass |
| 746 | return super().find_class(mod_name, name) |
| 747 | |
Sam Gross | e3e786e | 2016-11-03 16:29:14 -0400 | [diff] [blame] | 748 | def _check_container_source(container_type, source_file, original_source): |
Soumith Chintala | 5b142e5 | 2018-02-20 17:42:57 -0500 | [diff] [blame] | 749 | try: |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 750 | current_source = ''.join(get_source_lines_and_file(container_type)[0]) |
Soumith Chintala | 5b142e5 | 2018-02-20 17:42:57 -0500 | [diff] [blame] | 751 | except Exception: # saving the source is optional, so we can ignore any errors |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 752 | warnings.warn("Couldn't retrieve source code for container of " |
| 753 | "type " + container_type.__name__ + ". It won't be checked " |
| 754 | "for correctness upon loading.") |
Soumith Chintala | 5b142e5 | 2018-02-20 17:42:57 -0500 | [diff] [blame] | 755 | return |
Sam Gross | e3e786e | 2016-11-03 16:29:14 -0400 | [diff] [blame] | 756 | if original_source != current_source: |
| 757 | if container_type.dump_patches: |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 758 | file_name = container_type.__name__ + '.patch' |
| 759 | diff = difflib.unified_diff(current_source.split('\n'), |
| 760 | original_source.split('\n'), |
| 761 | source_file, |
| 762 | source_file, lineterm="") |
| 763 | lines = '\n'.join(diff) |
Sam Gross | e3e786e | 2016-11-03 16:29:14 -0400 | [diff] [blame] | 764 | try: |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 765 | with open(file_name, 'a+') as f: |
Sam Gross | e3e786e | 2016-11-03 16:29:14 -0400 | [diff] [blame] | 766 | file_size = f.seek(0, 2) |
| 767 | f.seek(0) |
| 768 | if file_size == 0: |
| 769 | f.write(lines) |
| 770 | elif file_size != len(lines) or f.read() != lines: |
| 771 | raise IOError |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 772 | msg = ("Saved a reverse patch to " + file_name + ". " |
| 773 | "Run `patch -p0 < " + file_name + "` to revert your " |
| 774 | "changes.") |
Sam Gross | e3e786e | 2016-11-03 16:29:14 -0400 | [diff] [blame] | 775 | except IOError: |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 776 | msg = ("Tried to save a patch, but couldn't create a " |
| 777 | "writable file " + file_name + ". Make sure it " |
| 778 | "doesn't exist and your working directory is " |
| 779 | "writable.") |
Sam Gross | e3e786e | 2016-11-03 16:29:14 -0400 | [diff] [blame] | 780 | else: |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 781 | msg = ("you can retrieve the original source code by " |
| 782 | "accessing the object's source attribute or set " |
| 783 | "`torch.nn.Module.dump_patches = True` and use the " |
| 784 | "patch tool to revert the changes.") |
Nikita Shulga | 0c01f13 | 2020-09-04 07:36:47 -0700 | [diff] [blame] | 785 | msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}" |
Sam Gross | e3e786e | 2016-11-03 16:29:14 -0400 | [diff] [blame] | 786 | warnings.warn(msg, SourceChangeWarning) |
| 787 | |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 788 | def legacy_load(f): |
Nikita Shulga | 591fffc | 2020-07-02 07:09:21 -0700 | [diff] [blame] | 789 | deserialized_objects: Dict[int, Any] = {} |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 790 | |
| 791 | def persistent_load(saved_id): |
| 792 | if isinstance(saved_id, tuple): |
| 793 | # Ignore containers that don't have any sources saved |
| 794 | if all(saved_id[1:]): |
| 795 | _check_container_source(*saved_id) |
| 796 | return saved_id[0] |
| 797 | return deserialized_objects[int(saved_id)] |
| 798 | |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 799 | with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \ |
| 800 | mkdtemp() as tmpdir: |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 801 | |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 802 | tar.extract('storages', path=tmpdir) |
| 803 | with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f: |
SsnL | 54d5c53 | 2018-12-10 08:05:06 -0800 | [diff] [blame] | 804 | num_storages = pickle_module.load(f, **pickle_load_args) |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 805 | for i in range(num_storages): |
SsnL | 54d5c53 | 2018-12-10 08:05:06 -0800 | [diff] [blame] | 806 | args = pickle_module.load(f, **pickle_load_args) |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 807 | key, location, storage_type = args |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 808 | dtype = storage_type.dtype |
Kurt Mohler | 8e7fe87 | 2022-02-15 15:43:57 -0800 | [diff] [blame] | 809 | obj = cast(Storage, torch._UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype)) |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 810 | obj = restore_location(obj, location) |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 811 | # TODO: Once we decide to break serialization FC, we can |
Kurt Mohler | 8e7fe87 | 2022-02-15 15:43:57 -0800 | [diff] [blame] | 812 | # stop wrapping with _TypedStorage |
| 813 | deserialized_objects[key] = torch.storage._TypedStorage( |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 814 | wrap_storage=obj, |
| 815 | dtype=dtype) |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 816 | |
SsnL | 54d5c53 | 2018-12-10 08:05:06 -0800 | [diff] [blame] | 817 | storage_views = pickle_module.load(f, **pickle_load_args) |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 818 | for target_cdata, root_cdata, offset, numel in storage_views: |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 819 | root = deserialized_objects[root_cdata] |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 820 | element_size = torch._utils._element_size(root.dtype) |
| 821 | offset_bytes = offset * element_size |
| 822 | # TODO: Once we decide to break serialization FC, we can |
Kurt Mohler | 8e7fe87 | 2022-02-15 15:43:57 -0800 | [diff] [blame] | 823 | # stop wrapping with _TypedStorage |
| 824 | deserialized_objects[target_cdata] = torch.storage._TypedStorage( |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 825 | wrap_storage=root._storage[offset_bytes:offset_bytes + numel * element_size], |
| 826 | dtype=root.dtype) |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 827 | |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 828 | tar.extract('tensors', path=tmpdir) |
| 829 | with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f: |
SsnL | 54d5c53 | 2018-12-10 08:05:06 -0800 | [diff] [blame] | 830 | num_tensors = pickle_module.load(f, **pickle_load_args) |
Sam Gross | 94a0c72 | 2017-12-05 11:24:54 -0500 | [diff] [blame] | 831 | for _ in range(num_tensors): |
SsnL | 54d5c53 | 2018-12-10 08:05:06 -0800 | [diff] [blame] | 832 | args = pickle_module.load(f, **pickle_load_args) |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 833 | key, storage_id, original_tensor_type = args |
| 834 | storage = deserialized_objects[storage_id] |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 835 | ndim, = struct.unpack('<i', f.read(4)) |
Sam Gross | 94a0c72 | 2017-12-05 11:24:54 -0500 | [diff] [blame] | 836 | # skip next 4 bytes; legacy encoding treated ndim as 8 bytes |
| 837 | f.read(4) |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 838 | numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim)) |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 839 | stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)) |
| 840 | storage_offset, = struct.unpack('<q', f.read(8)) |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 841 | tensor = torch.tensor([], dtype=storage.dtype).set_( |
| 842 | storage._storage, storage_offset, numel, stride) |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 843 | deserialized_objects[key] = tensor |
| 844 | |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 845 | pickle_file = tar.extractfile('pickle') |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 846 | unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args) |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 847 | unpickler.persistent_load = persistent_load |
| 848 | result = unpickler.load() |
| 849 | return result |
| 850 | |
| 851 | deserialized_objects = {} |
| 852 | |
Adam Paszke | 686e8d3 | 2016-08-22 22:11:50 -0400 | [diff] [blame] | 853 | def persistent_load(saved_id): |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 854 | assert isinstance(saved_id, tuple) |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 855 | typename = _maybe_decode_ascii(saved_id[0]) |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 856 | data = saved_id[1:] |
| 857 | |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 858 | if typename == 'module': |
Adam Paszke | 2bd7a3c | 2016-12-01 20:35:35 +0100 | [diff] [blame] | 859 | # Ignore containers that don't have any sources saved |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 860 | if all(data[1:]): |
| 861 | _check_container_source(*data) |
| 862 | return data[0] |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 863 | elif typename == 'storage': |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 864 | storage_type, root_key, location, numel, view_metadata = data |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 865 | location = _maybe_decode_ascii(location) |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 866 | dtype = storage_type.dtype |
| 867 | |
| 868 | nbytes = numel * torch._utils._element_size(dtype) |
| 869 | |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 870 | if root_key not in deserialized_objects: |
Kurt Mohler | 8e7fe87 | 2022-02-15 15:43:57 -0800 | [diff] [blame] | 871 | obj = cast(Storage, torch._UntypedStorage(nbytes)) |
Luca Wehrstedt | 29f4f8f | 2019-02-21 01:24:56 -0800 | [diff] [blame] | 872 | obj._torch_load_uninitialized = True |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 873 | # TODO: Once we decide to break serialization FC, we can |
Kurt Mohler | 8e7fe87 | 2022-02-15 15:43:57 -0800 | [diff] [blame] | 874 | # stop wrapping with _TypedStorage |
| 875 | deserialized_objects[root_key] = torch.storage._TypedStorage( |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 876 | wrap_storage=restore_location(obj, location), |
| 877 | dtype=dtype) |
| 878 | |
| 879 | typed_storage = deserialized_objects[root_key] |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 880 | if view_metadata is not None: |
| 881 | view_key, offset, view_size = view_metadata |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 882 | offset_bytes = offset * torch._utils._element_size(dtype) |
| 883 | view_size_bytes = view_size * torch._utils._element_size(dtype) |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 884 | if view_key not in deserialized_objects: |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 885 | # TODO: Once we decide to break serialization FC, we can |
Kurt Mohler | 8e7fe87 | 2022-02-15 15:43:57 -0800 | [diff] [blame] | 886 | # stop wrapping with _TypedStorage |
| 887 | deserialized_objects[view_key] = torch.storage._TypedStorage( |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 888 | wrap_storage=typed_storage._storage[offset_bytes:offset_bytes + view_size_bytes], |
| 889 | dtype=dtype) |
| 890 | res = deserialized_objects[view_key] |
| 891 | |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 892 | else: |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 893 | res = typed_storage |
| 894 | return res |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 895 | else: |
| 896 | raise RuntimeError("Unknown saved id type: %s" % saved_id[0]) |
Adam Paszke | 686e8d3 | 2016-08-22 22:11:50 -0400 | [diff] [blame] | 897 | |
Wei Yang | c3e4b3c | 2018-06-12 12:57:28 -0700 | [diff] [blame] | 898 | _check_seekable(f) |
li-roy | bafec16 | 2018-05-31 12:06:38 -0700 | [diff] [blame] | 899 | f_should_read_directly = _should_read_directly(f) |
Wei Yang | c3e4b3c | 2018-06-12 12:57:28 -0700 | [diff] [blame] | 900 | |
li-roy | bafec16 | 2018-05-31 12:06:38 -0700 | [diff] [blame] | 901 | if f_should_read_directly and f.tell() == 0: |
Richard Zou | 8ba8713 | 2018-03-08 22:18:55 -0500 | [diff] [blame] | 902 | # legacy_load requires that f has fileno() |
Philipp Lang | c4b0db5 | 2017-11-17 20:21:37 +0000 | [diff] [blame] | 903 | # only if offset is zero we can attempt the legacy tar file loader |
| 904 | try: |
| 905 | return legacy_load(f) |
| 906 | except tarfile.TarError: |
davidriazati | 7a921ba | 2019-08-30 16:43:45 -0700 | [diff] [blame] | 907 | if _is_zipfile(f): |
David Riazati | 692898f | 2018-12-28 13:52:01 -0800 | [diff] [blame] | 908 | # .zip is used for torch.jit.save and will throw an un-pickling error here |
olramde | d770fbc | 2020-01-02 12:45:14 -0800 | [diff] [blame] | 909 | raise RuntimeError( |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 910 | f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None |
Philipp Lang | c4b0db5 | 2017-11-17 20:21:37 +0000 | [diff] [blame] | 911 | # if not a tarfile, reset file offset and proceed |
Richard Zou | 8ba8713 | 2018-03-08 22:18:55 -0500 | [diff] [blame] | 912 | f.seek(0) |
Adam Paszke | 686e8d3 | 2016-08-22 22:11:50 -0400 | [diff] [blame] | 913 | |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 914 | if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2): |
Nathan Goldbaum | 84101f3 | 2020-02-26 21:12:42 -0800 | [diff] [blame] | 915 | raise RuntimeError( |
| 916 | "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. " |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 917 | f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this " |
| 918 | "functionality.") |
Nathan Goldbaum | 84101f3 | 2020-02-26 21:12:42 -0800 | [diff] [blame] | 919 | |
SsnL | 54d5c53 | 2018-12-10 08:05:06 -0800 | [diff] [blame] | 920 | magic_number = pickle_module.load(f, **pickle_load_args) |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 921 | if magic_number != MAGIC_NUMBER: |
| 922 | raise RuntimeError("Invalid magic number; corrupt file?") |
SsnL | 54d5c53 | 2018-12-10 08:05:06 -0800 | [diff] [blame] | 923 | protocol_version = pickle_module.load(f, **pickle_load_args) |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 924 | if protocol_version != PROTOCOL_VERSION: |
| 925 | raise RuntimeError("Invalid protocol version: %s" % protocol_version) |
Adam Paszke | 686e8d3 | 2016-08-22 22:11:50 -0400 | [diff] [blame] | 926 | |
SsnL | 54d5c53 | 2018-12-10 08:05:06 -0800 | [diff] [blame] | 927 | _sys_info = pickle_module.load(f, **pickle_load_args) |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 928 | unpickler = UnpicklerWrapper(f, **pickle_load_args) |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 929 | unpickler.persistent_load = persistent_load |
| 930 | result = unpickler.load() |
Adam Paszke | 4a8a185 | 2016-09-24 11:38:12 -0700 | [diff] [blame] | 931 | |
SsnL | 54d5c53 | 2018-12-10 08:05:06 -0800 | [diff] [blame] | 932 | deserialized_storage_keys = pickle_module.load(f, **pickle_load_args) |
Adam Paszke | 686e8d3 | 2016-08-22 22:11:50 -0400 | [diff] [blame] | 933 | |
li-roy | bafec16 | 2018-05-31 12:06:38 -0700 | [diff] [blame] | 934 | offset = f.tell() if f_should_read_directly else None |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 935 | for key in deserialized_storage_keys: |
| 936 | assert key in deserialized_objects |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 937 | typed_storage = deserialized_objects[key] |
| 938 | typed_storage._storage._set_from_file( |
| 939 | f, offset, f_should_read_directly, |
| 940 | torch._utils._element_size(typed_storage.dtype)) |
Philipp Lang | f23fb66 | 2019-05-09 08:16:20 -0700 | [diff] [blame] | 941 | if offset is not None: |
| 942 | offset = f.tell() |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 943 | |
Wojciech Baranowski | fcadca1 | 2020-06-30 22:29:22 -0700 | [diff] [blame] | 944 | torch._utils._validate_loaded_sparse_tensors() |
| 945 | |
Adam Lerer | e71cf20 | 2017-02-22 16:24:20 -0500 | [diff] [blame] | 946 | return result |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 947 | |
| 948 | |
Nikita Shulga | 591fffc | 2020-07-02 07:09:21 -0700 | [diff] [blame] | 949 | def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str: |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 950 | # When using encoding='bytes' in Py3, some **internal** keys stored as |
| 951 | # strings in Py2 are loaded as bytes. This function decodes them with |
| 952 | # ascii encoding, one that Py3 uses by default. |
| 953 | # |
| 954 | # NOTE: This should only be used on internal keys (e.g., `typename` and |
| 955 | # `location` in `persistent_load` below! |
| 956 | if isinstance(bytes_str, bytes): |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 957 | return bytes_str.decode('ascii') |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 958 | return bytes_str |
| 959 | |
| 960 | |
def _get_restore_location(map_location):
    """Normalize ``map_location`` into a ``restore_location(storage, location)`` callable.

    ``map_location`` may be ``None`` (use the default restore behavior), a
    dict remapping location tags, a location string, a ``torch.device``, or a
    callable.  A callable that returns ``None`` falls back to the default
    restore behavior for that storage.
    """
    if map_location is None:
        return default_restore_location

    if isinstance(map_location, dict):
        def restore_location(storage, location):
            # Remap the tag if present in the dict, otherwise keep it as-is.
            return default_restore_location(
                storage, map_location.get(location, location))
        return restore_location

    if isinstance(map_location, _string_classes):
        def restore_location(storage, location):
            # Ignore the saved tag; force everything to the given location.
            return default_restore_location(storage, map_location)
        return restore_location

    if isinstance(map_location, torch.device):
        def restore_location(storage, location):
            return default_restore_location(storage, str(map_location))
        return restore_location

    def restore_location(storage, location):
        # User callable; None means "no opinion", so defer to the default.
        mapped = map_location(storage, location)
        return mapped if mapped is not None else default_restore_location(storage, location)
    return restore_location
| 981 | |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 982 | class StorageType(): |
| 983 | def __init__(self, name): |
| 984 | self.dtype = _get_dtype_from_pickle_storage_type(name) |
| 985 | |
| 986 | def __str__(self): |
| 987 | return f'StorageType(dtype={self.dtype})' |
| 988 | |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 989 | def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickle_load_args): |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 990 | restore_location = _get_restore_location(map_location) |
| 991 | |
| 992 | loaded_storages = {} |
| 993 | |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 994 | def load_tensor(dtype, numel, key, location): |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 995 | name = f'data/{key}' |
davidriazati | da8191a | 2020-06-04 16:57:17 -0700 | [diff] [blame] | 996 | |
Kurt Mohler | 8e7fe87 | 2022-02-15 15:43:57 -0800 | [diff] [blame] | 997 | storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped() |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 998 | # TODO: Once we decide to break serialization FC, we can |
Kurt Mohler | 8e7fe87 | 2022-02-15 15:43:57 -0800 | [diff] [blame] | 999 | # stop wrapping with _TypedStorage |
| 1000 | loaded_storages[key] = torch.storage._TypedStorage( |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 1001 | wrap_storage=restore_location(storage, location), |
| 1002 | dtype=dtype) |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 1003 | |
| 1004 | def persistent_load(saved_id): |
| 1005 | assert isinstance(saved_id, tuple) |
| 1006 | typename = _maybe_decode_ascii(saved_id[0]) |
| 1007 | data = saved_id[1:] |
| 1008 | |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 1009 | assert typename == 'storage', \ |
| 1010 | f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 1011 | storage_type, key, location, numel = data |
| 1012 | dtype = storage_type.dtype |
| 1013 | |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 1014 | if key not in loaded_storages: |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 1015 | nbytes = numel * torch._utils._element_size(dtype) |
| 1016 | load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location)) |
| 1017 | |
| 1018 | return loaded_storages[key] |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 1019 | |
Philip Meier | b0afe94 | 2021-03-09 11:28:03 -0800 | [diff] [blame] | 1020 | load_module_mapping: Dict[str, str] = { |
| 1021 | # See https://github.com/pytorch/pytorch/pull/51633 |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 1022 | 'torch.tensor': 'torch._tensor' |
Philip Meier | b0afe94 | 2021-03-09 11:28:03 -0800 | [diff] [blame] | 1023 | } |
Brian Hirsh | 1827713 | 2021-03-04 17:08:51 -0800 | [diff] [blame] | 1024 | |
| 1025 | # Need to subclass Unpickler instead of directly monkey-patching the find_class method |
| 1026 | # because it's marked readonly in pickle. |
| 1027 | # The type: ignore is because mypy can't statically determine the type of this class. |
| 1028 | class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] |
| 1029 | # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732 |
| 1030 | # Lets us override the imports that pickle uses when unpickling an object. |
| 1031 | # This is useful for maintaining BC if we change a module path that tensor instantiation relies on. |
| 1032 | def find_class(self, mod_name, name): |
Kurt Mohler | 5883523 | 2021-10-05 13:48:45 -0700 | [diff] [blame] | 1033 | if type(name) is str and 'Storage' in name: |
| 1034 | try: |
| 1035 | return StorageType(name) |
| 1036 | except KeyError: |
| 1037 | pass |
Brian Hirsh | 1827713 | 2021-03-04 17:08:51 -0800 | [diff] [blame] | 1038 | mod_name = load_module_mapping.get(mod_name, mod_name) |
| 1039 | return super().find_class(mod_name, name) |
| 1040 | |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 1041 | # Load the data (which may in turn use `persistent_load` to load tensors) |
Zachary DeVito | cb75add | 2020-09-22 21:15:15 -0700 | [diff] [blame] | 1042 | data_file = io.BytesIO(zip_file.get_record(pickle_file)) |
Brian Hirsh | 1827713 | 2021-03-04 17:08:51 -0800 | [diff] [blame] | 1043 | |
| 1044 | unpickler = UnpicklerWrapper(data_file, **pickle_load_args) |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 1045 | unpickler.persistent_load = persistent_load |
| 1046 | result = unpickler.load() |
| 1047 | |
Wojciech Baranowski | fcadca1 | 2020-06-30 22:29:22 -0700 | [diff] [blame] | 1048 | torch._utils._validate_loaded_sparse_tensors() |
| 1049 | |
David Riazati | dca123e | 2019-11-19 10:14:44 -0800 | [diff] [blame] | 1050 | return result |
David Riazati | 8c6f0c0 | 2019-11-22 12:28:49 -0800 | [diff] [blame] | 1051 | |
| 1052 | |
| 1053 | def _is_torchscript_zip(zip_file): |
Shen Li | 1022443 | 2021-08-12 11:39:31 -0700 | [diff] [blame] | 1054 | return 'constants.pkl' in zip_file.get_all_records() |