blob: 3cc92349bbc52f909059144e02ef078c20e95bf3 [file] [log] [blame]
Sam Grosse3e786e2016-11-03 16:29:14 -04001import difflib
Zsolt Dollensteinb0043072021-08-12 10:56:55 -07002import os
Shen Li10224432021-08-12 11:39:31 -07003import io
Adam Paszke686e8d32016-08-22 22:11:50 -04004import shutil
5import struct
Sam Grosse3e786e2016-11-03 16:29:14 -04006import sys
Shen Li10224432021-08-12 11:39:31 -07007import torch
Sam Grosse3e786e2016-11-03 16:29:14 -04008import tarfile
9import tempfile
10import warnings
Adam Paszke686e8d32016-08-22 22:11:50 -040011from contextlib import closing, contextmanager
Shen Li10224432021-08-12 11:39:31 -070012from ._utils import _import_dotted_name
13from ._six import string_classes as _string_classes
Zhengxu Chene62189a2021-08-05 14:19:56 -070014from torch._sources import get_source_lines_and_file
Nikita Shulga591fffc2020-07-02 07:09:21 -070015from torch.types import Storage
Kurt Mohler58835232021-10-05 13:48:45 -070016from torch.storage import _get_dtype_from_pickle_storage_type
Shen Li10224432021-08-12 11:39:31 -070017from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO
18import copyreg
19import pickle
20import pathlib
Adam Paszke686e8d32016-08-22 22:11:50 -040021
# Default pickle protocol used by torch.save (protocol 2 for broad compatibility).
DEFAULT_PROTOCOL = 2

# Native sizes of C integer types ('=' -> standard size, native byte order),
# recorded in the legacy file header so loaders can sanity-check the platform.
LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size

# Sentinel value written at the start of a legacy torch.save stream.
MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
# Version of the legacy serialization protocol, written after MAGIC_NUMBER.
PROTOCOL_VERSION = 1001
STORAGE_KEY_SEPARATOR = ','
Adam Lerere71cf202017-02-22 16:24:20 -050031
class SourceChangeWarning(Warning):
    """Warning category used when a saved container's source code appears to
    have changed since it was serialized (emitted during loading)."""
    pass
34
35
@contextmanager
def mkdtemp():
    """Context manager that yields a fresh temporary directory path.

    The directory is removed on exit. The cleanup is wrapped in try/finally
    so the directory is also removed when the managed block raises — the
    previous implementation leaked the directory on exceptions.
    """
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        shutil.rmtree(path)
41
42
Adam Paszke0c9670d2016-10-04 11:33:00 -070043_package_registry = []
44
45
Nikita Shulga591fffc2020-07-02 07:09:21 -070046def _is_zipfile(f) -> bool:
davidriazati7a921ba2019-08-30 16:43:45 -070047 # This is a stricter implementation than zipfile.is_zipfile().
48 # zipfile.is_zipfile() is True if the magic number appears anywhere in the
49 # binary. Since we expect the files here to be generated by torch.save or
50 # torch.jit.save, it's safe to only check the start bytes and avoid
davidriazati74ce3a02020-02-05 15:30:21 -080051 # collisions and assume the zip has only 1 file.
52 # See bugs.python.org/issue28494.
davidriazati7a921ba2019-08-30 16:43:45 -070053
54 # Read the first 4 bytes of the file
55 read_bytes = []
56 start = f.tell()
57
davidriazati7a921ba2019-08-30 16:43:45 -070058 byte = f.read(1)
59 while byte != "":
60 read_bytes.append(byte)
61 if len(read_bytes) == 4:
62 break
63 byte = f.read(1)
64 f.seek(start)
65
Shen Li10224432021-08-12 11:39:31 -070066 local_header_magic_number = [b'P', b'K', b'\x03', b'\x04']
davidriazati74ce3a02020-02-05 15:30:21 -080067 return read_bytes == local_header_magic_number
davidriazati7a921ba2019-08-30 16:43:45 -070068
69
def register_package(priority, tagger, deserializer):
    """Register a location tagger / deserializer pair.

    Entries live in the module-level ``_package_registry`` and are kept
    sorted by ``priority`` (lower numbers are consulted first).
    """
    _package_registry.append((priority, tagger, deserializer))
    _package_registry.sort()
74
75
def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
    '''
    Check if a module's version satisfies requirements

    Usually, a module's version string will be like 'x.y.z', which would be represented
    as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version
    string does not match the given tuple's format up to the length of the tuple, then
    error and exit or emit a warning.

    Args:
        module: the module to check the version of
        req_version_tuple: tuple (usually of ints) representing the required version
        error_if_malformed: whether we should exit if module version string is malformed

    Returns:
        requirement_is_met: bool
    '''
    try:
        fields = module.__version__.split('.')
        # Coerce each parsed field to the type of the corresponding required
        # field so the two tuples compare element-wise with matching types.
        actual_version = tuple(
            type(required)(fields[idx]) for idx, required in enumerate(req_version_tuple)
        )
        requirement_is_met = actual_version >= req_version_tuple

    except Exception as e:
        message = (
            "'%s' module version string is malformed '%s' and cannot be compared"
            " with tuple %s"
        ) % (
            module.__name__, module.__version__, str(req_version_tuple)
        )
        if error_if_malformed:
            raise RuntimeError(message) from e
        warnings.warn(message + ', but continuing assuming that requirement is met')
        requirement_is_met = True

    return requirement_is_met
115
116
Adam Paszke0c9670d2016-10-04 11:33:00 -0700117def _cpu_tag(obj):
Shen Li10224432021-08-12 11:39:31 -0700118 if type(obj).__module__ == 'torch':
119 return 'cpu'
Adam Paszke0c9670d2016-10-04 11:33:00 -0700120
121
122def _cuda_tag(obj):
Shen Li10224432021-08-12 11:39:31 -0700123 if type(obj).__module__ == 'torch.cuda':
124 return 'cuda:' + str(obj.get_device())
Adam Paszke0c9670d2016-10-04 11:33:00 -0700125
126
127def _cpu_deserialize(obj, location):
Shen Li10224432021-08-12 11:39:31 -0700128 if location == 'cpu':
Adam Paszke0c9670d2016-10-04 11:33:00 -0700129 return obj
130
131
def validate_cuda_device(location):
    """Resolve a ``'cuda[:N]'`` location tag to a usable device index.

    Raises RuntimeError if CUDA is unavailable or the index exceeds the
    number of visible devices; returns the integer device index otherwise.
    """
    device = torch.cuda._utils._get_device_index(location, True)

    if not torch.cuda.is_available():
        raise RuntimeError('Attempting to deserialize object on a CUDA '
                           'device but torch.cuda.is_available() is False. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')
    device_count = torch.cuda.device_count()
    if device >= device_count:
        raise RuntimeError('Attempting to deserialize object on CUDA device '
                           f'{device} but torch.cuda.device_count() is {device_count}. Please use '
                           'torch.load with map_location to map your storages '
                           'to an existing device.')
    return device
148
149
def _cuda_deserialize(obj, location):
    # Restore a storage tagged 'cuda[:N]' onto the corresponding CUDA device;
    # non-cuda tags fall through (implicit None) to other deserializers.
    if location.startswith('cuda'):
        device = validate_cuda_device(location)
        if getattr(obj, "_torch_load_uninitialized", False):
            # Storage was flagged as lazily-loaded: allocate a same-sized CUDA
            # storage on the target device instead of copying the CPU bytes
            # (presumably the contents are filled in later by the loader —
            # NOTE(review): confirm against the load path).
            storage_type = getattr(torch.cuda, type(obj).__name__)
            with torch.cuda.device(device):
                return storage_type(obj.nbytes())
        else:
            return obj.cuda(device)
Adam Paszke0c9670d2016-10-04 11:33:00 -0700159
160
161register_package(10, _cpu_tag, _cpu_deserialize)
162register_package(20, _cuda_tag, _cuda_deserialize)
163
164
def location_tag(storage: Union[Storage, torch.storage._TypedStorage]):
    """Return the location tag produced by the first registered tagger that
    recognizes ``storage``; raise RuntimeError if none does."""
    for entry in _package_registry:
        tag = entry[1](storage)
        if tag:
            return tag
    raise RuntimeError("don't know how to determine data location of "
                       + torch.typename(storage))
Adam Paszke0c9670d2016-10-04 11:33:00 -0700172
173
174def default_restore_location(storage, location):
175 for _, _, fn in _package_registry:
176 result = fn(storage, location)
177 if result is not None:
178 return result
Shen Li10224432021-08-12 11:39:31 -0700179 raise RuntimeError("don't know how to restore data location of "
180 + torch.typename(storage) + " (tagged with "
181 + location + ")")
Adam Paszke0c9670d2016-10-04 11:33:00 -0700182
183
184def normalize_storage_type(storage_type):
185 return getattr(torch, storage_type.__name__)
186
187
188def storage_to_tensor_type(storage):
189 storage_type = type(storage)
190 module = _import_dotted_name(storage_type.__module__)
Shen Li10224432021-08-12 11:39:31 -0700191 return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
Adam Paszke0c9670d2016-10-04 11:33:00 -0700192
193
Your Namebfedace2019-11-14 13:35:37 -0800194def _is_path(name_or_buffer):
Shen Li10224432021-08-12 11:39:31 -0700195 return isinstance(name_or_buffer, str) or \
196 isinstance(name_or_buffer, pathlib.Path)
Your Namefff4f162019-11-06 18:40:10 -0800197
Your Namefff4f162019-11-06 18:40:10 -0800198
Your Namebfedace2019-11-14 13:35:37 -0800199class _opener(object):
200 def __init__(self, file_like):
201 self.file_like = file_like
Your Namefff4f162019-11-06 18:40:10 -0800202
203 def __enter__(self):
Your Namefff4f162019-11-06 18:40:10 -0800204 return self.file_like
205
206 def __exit__(self, *args):
Your Namebfedace2019-11-14 13:35:37 -0800207 pass
208
209
210class _open_file(_opener):
211 def __init__(self, name, mode):
212 super(_open_file, self).__init__(open(name, mode))
213
214 def __exit__(self, *args):
215 self.file_like.close()
216
217
218class _open_buffer_reader(_opener):
David Riazatidca123e2019-11-19 10:14:44 -0800219 def __init__(self, buffer):
220 super(_open_buffer_reader, self).__init__(buffer)
221 _check_seekable(buffer)
Your Namebfedace2019-11-14 13:35:37 -0800222
223
224class _open_buffer_writer(_opener):
225 def __exit__(self, *args):
226 self.file_like.flush()
227
228
229def _open_file_like(name_or_buffer, mode):
230 if _is_path(name_or_buffer):
231 return _open_file(name_or_buffer, mode)
232 else:
Shen Li10224432021-08-12 11:39:31 -0700233 if 'w' in mode:
Your Namebfedace2019-11-14 13:35:37 -0800234 return _open_buffer_writer(name_or_buffer)
Shen Li10224432021-08-12 11:39:31 -0700235 elif 'r' in mode:
Your Namebfedace2019-11-14 13:35:37 -0800236 return _open_buffer_reader(name_or_buffer)
237 else:
Nikita Shulga0c01f132020-09-04 07:36:47 -0700238 raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
Your Namebfedace2019-11-14 13:35:37 -0800239
240
class _open_zipfile_reader(_opener):
    """Opener wrapping a ``PyTorchFileReader`` over a path or buffer."""

    def __init__(self, name_or_buffer) -> None:
        reader = torch._C.PyTorchFileReader(name_or_buffer)
        super().__init__(reader)
Your Namebfedace2019-11-14 13:35:37 -0800244
245
246class _open_zipfile_writer_file(_opener):
Nikita Shulga591fffc2020-07-02 07:09:21 -0700247 def __init__(self, name) -> None:
Shen Li10224432021-08-12 11:39:31 -0700248 super(_open_zipfile_writer_file, self).__init__(torch._C.PyTorchFileWriter(str(name)))
Your Namebfedace2019-11-14 13:35:37 -0800249
Nikita Shulga591fffc2020-07-02 07:09:21 -0700250 def __exit__(self, *args) -> None:
Your Namebfedace2019-11-14 13:35:37 -0800251 self.file_like.write_end_of_file()
252
253
254class _open_zipfile_writer_buffer(_opener):
Nikita Shulga591fffc2020-07-02 07:09:21 -0700255 def __init__(self, buffer) -> None:
Your Namebfedace2019-11-14 13:35:37 -0800256 self.buffer = buffer
Shen Li10224432021-08-12 11:39:31 -0700257 super(_open_zipfile_writer_buffer, self).__init__(torch._C.PyTorchFileWriter(buffer))
Your Namebfedace2019-11-14 13:35:37 -0800258
Nikita Shulga591fffc2020-07-02 07:09:21 -0700259 def __exit__(self, *args) -> None:
Your Namebfedace2019-11-14 13:35:37 -0800260 self.file_like.write_end_of_file()
261 self.buffer.flush()
262
263
264def _open_zipfile_writer(name_or_buffer):
Nikita Shulga591fffc2020-07-02 07:09:21 -0700265 container: Type[_opener]
Your Namebfedace2019-11-14 13:35:37 -0800266 if _is_path(name_or_buffer):
267 container = _open_zipfile_writer_file
268 else:
269 container = _open_zipfile_writer_buffer
270 return container(name_or_buffer)
Edward Z. Yang57eb8bd2017-08-31 08:46:30 -0700271
272
Nikita Shulga591fffc2020-07-02 07:09:21 -0700273def _is_compressed_file(f) -> bool:
Shen Li10224432021-08-12 11:39:31 -0700274 compress_modules = ['gzip']
li-roybafec162018-05-31 12:06:38 -0700275 try:
276 return f.__module__ in compress_modules
277 except AttributeError:
278 return False
279
280
281def _should_read_directly(f):
282 """
283 Checks if f is a file that should be read directly. It should be read
284 directly if it is backed by a real file (has a fileno) and is not a
285 a compressed file (e.g. gzip)
286 """
287 if _is_compressed_file(f):
288 return False
Richard Zou8ba87132018-03-08 22:18:55 -0500289 try:
290 return f.fileno() >= 0
291 except io.UnsupportedOperation:
292 return False
293 except AttributeError:
294 return False
295
296
Nikita Shulga591fffc2020-07-02 07:09:21 -0700297def _check_seekable(f) -> bool:
Shen Li10224432021-08-12 11:39:31 -0700298
Wei Yangc3e4b3c2018-06-12 12:57:28 -0700299 def raise_err_msg(patterns, e):
300 for p in patterns:
301 if p in str(e):
Shen Li10224432021-08-12 11:39:31 -0700302 msg = (str(e) + ". You can only torch.load from a file that is seekable."
303 + " Please pre-load the data into a buffer like io.BytesIO and"
304 + " try to load from it instead.")
Wei Yangc3e4b3c2018-06-12 12:57:28 -0700305 raise type(e)(msg)
306 raise e
307
308 try:
309 f.seek(f.tell())
310 return True
311 except (io.UnsupportedOperation, AttributeError) as e:
312 raise_err_msg(["seek", "tell"], e)
Nikita Shulga591fffc2020-07-02 07:09:21 -0700313 return False
Wei Yangc3e4b3c2018-06-12 12:57:28 -0700314
Nikita Shulga591fffc2020-07-02 07:09:21 -0700315def _check_dill_version(pickle_module) -> None:
Shen Li10224432021-08-12 11:39:31 -0700316 '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
Kurt Mohler36947492019-12-18 08:03:39 -0800317 If dill version is lower than 0.3.1, a ValueError is raised.
318
319 Args:
320 pickle_module: module used for pickling metadata and objects
321
Shen Li10224432021-08-12 11:39:31 -0700322 '''
323 if pickle_module.__name__ == 'dill':
Kurt Mohler36947492019-12-18 08:03:39 -0800324 required_dill_version = (0, 3, 1)
Shen Li10224432021-08-12 11:39:31 -0700325 if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False):
326 raise ValueError((
327 "'torch' supports dill >= %s, but you have dill %s."
328 " Please upgrade dill or switch to 'pickle'"
329 ) % (
330 '.'.join([str(num) for num in required_dill_version]),
331 pickle_module.__version__
332 ))
Wei Yangc3e4b3c2018-06-12 12:57:28 -0700333
def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
         pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path').

    """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

    Saves an object to a disk file.

    See also: :ref:`saving-loading-tensors`

    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string or
           os.PathLike object containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol

    .. note::
        A common PyTorch convention is to save tensors using .pt file extension.

    .. note::
        PyTorch preserves storage sharing across serialization. See
        :ref:`preserve-storage-sharing` for more details.

    .. note::
        The 1.6 release of PyTorch switched ``torch.save`` to use a new
        zipfile-based file format. ``torch.load`` still retains the ability to
        load files in the old format. If for any reason you want ``torch.save``
        to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.

    Example:
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
    """
    # dill (if used in place of pickle) must be >= 0.3.1.
    _check_dill_version(pickle_module)

    with _open_file_like(f, 'wb') as opened_file:
        if _use_new_zipfile_serialization:
            # Default path (1.6+): a zip archive holding 'data.pkl' for the
            # object graph plus one record per storage.
            with _open_zipfile_writer(opened_file) as opened_zipfile:
                _save(obj, opened_zipfile, pickle_module, pickle_protocol)
                return
        # Legacy path: sequential pickle stream followed by raw storage bytes.
        _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
Adam Paszkee867baa2016-11-01 13:22:01 +0100383
384
def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
    """Serialize ``obj`` to ``f`` in the pre-1.6 legacy format: magic number,
    protocol version, and system info pickles, then the object pickle (with
    storages externalized via persistent_id), then the sorted storage keys,
    then each storage's raw bytes."""
    import torch.nn as nn
    serialized_container_types = {}
    # Maps str(storage._cdata) -> (storage, dtype) for every storage reached
    # while pickling `obj`; written out after the object pickle.
    serialized_storages = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj: Any) -> Optional[Tuple]:
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            # nn.Module *classes*: record their source so the loader can warn
            # if it changed. Each class is serialized at most once.
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)
                source = ''.join(source_lines)
            except Exception:  # saving the source is optional, so we can ignore any errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)

        if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage._TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._storage
                storage_dtype = obj.dtype
                storage_type_str = obj.pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                dtype = obj.dtype
                storage_numel = obj.size()

            else:
                # Raw (untyped) storage: treat its contents as bytes.
                storage = obj
                storage_dtype = storage.dtype
                storage_type = normalize_storage_type(type(obj))
                dtype = torch.uint8
                storage_numel = cast(Storage, storage).nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            'Cannot save multiple tensors or storages that '
                            'view the same data as different types')
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            view_metadata: Optional[Tuple[str, int, int]]
            storage = cast(Storage, storage)

            # Offset is always 0, but we keep it for backwards compatibility
            # with the old serialization format (which supported storage views)
            offset = 0
            storage_key = str(storage._cdata)
            location = location_tag(storage)

            # TODO: There's an issue here with FC. It might be impossible to
            # solve, but it's worth noting. Imagine we save a list `[storage,
            # tensor]`, where `tensor.storage()` is the same as `storage`, and
            # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
            # torch.float`. The storage will be serialized with element size
            # of 1, since we're choosing to serialize the first occurance of
            # a duplicate storage. Since this legacy serialization format saves
            # the numel of the storage, rather than nbytes directly, we'll be
            # effectively saving nbytes in this case. We'll be able to load it
            # and the tensor back up with no problems in _this_ and future
            # versions of pytorch, but in older versions, here's the problem:
            # the storage will be loaded up as a _UntypedStorage, and then the
            # FloatTensor will loaded and the _UntypedStorage will be assigned to
            # it. Since the storage dtype does not match the tensor dtype, this
            # will cause an error. If we reverse the list, like `[tensor,
            # storage]`, then we will save the `tensor.storage()` as a faked
            # `FloatStorage`, and the saved size will be the correct
            # dtype-specific numel count that old versions expect. `tensor`
            # will be able to load up properly in old versions, pointing to
            # a FloatStorage. However, `storage` is still being translated to
            # a _UntypedStorage, and it will try to resolve to the same
            # FloatStorage that `tensor` contains. This will also cause an
            # error. It doesn't seem like there's any way around this.
            # Probably, we just cannot maintain FC for the legacy format if the
            # saved list contains both a tensor and a storage that point to the
            # same data. We should still be able to maintain FC for lists of
            # just tensors, as long as all views share the same dtype as the
            # tensor they are viewing.

            if storage_key not in serialized_storages:
                serialized_storages[storage_key] = (storage, dtype)
            # NOTE(review): this comparison of a value with itself is always
            # False — storage views were removed, so view_metadata is always
            # None; the branch is kept only for format compatibility.
            is_view = storage._cdata != storage._cdata
            if is_view:
                view_metadata = (str(storage._cdata), offset, storage.nbytes())
            else:
                view_metadata = None

            res = ('storage',
                   storage_type,
                   storage_key,
                   location,
                   storage_numel,
                   view_metadata)
            return res
        return None

    # Header metadata: byte order and native integer sizes let the loader
    # detect platform mismatches.
    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    # Raw storage bytes follow, in sorted-key order so the loader can match
    # them back to the keys pickled just above.
    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    f.flush()
    for key in serialized_storage_keys:
        storage, dtype = serialized_storages[key]
        storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype))
Adam Paszke686e8d32016-08-22 22:11:50 -0400524
525
def _save(obj, zip_file, pickle_module, pickle_protocol):
    """Serialize ``obj`` into ``zip_file`` in the 1.6+ zip-based format: a
    'data.pkl' record holding the object pickle (with storages externalized
    via persistent_id) plus one 'data/<key>' record per storage."""
    # Maps storage key -> storage for every storage reached while pickling.
    serialized_storages = {}
    # Assigns each distinct storage (by _cdata) a small sequential string key.
    id_map: Dict[int, str] = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj):

            if isinstance(obj, torch.storage._TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._storage
                storage_dtype = obj.dtype
                storage_type_str = obj.pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj.size()

            else:
                # Raw (untyped) storage: serialized as plain bytes.
                storage = obj
                storage_dtype = storage.dtype
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            storage = cast(Storage, storage)

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            'Cannot save multiple tensors or storages that '
                            'view the same data as different types')
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
            location = location_tag(storage)
            serialized_storages[storage_key] = storage

            return ('storage',
                    storage_type,
                    storage_key,
                    location,
                    storage_numel)

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    zip_file.write_record('data.pkl', data_value, len(data_value))

    # Write each tensor to a file named tensor/the_tensor_key in the zip archive
    for key in sorted(serialized_storages.keys()):
        name = f'data/{key}'
        storage = serialized_storages[key]
        # given that we copy things around anyway, we might use storage.cpu()
        # this means to that to get tensors serialized, you need to implement
        # .cpu() on the underlying Storage
        if storage.device.type != 'cpu':
            storage = storage.cpu()
        # Now that it is on the CPU we can directly copy it into the zip file
        num_bytes = storage.nbytes()
        zip_file.write_record(name, storage.data_ptr(), num_bytes)
David Riazatidca123e2019-11-19 10:14:44 -0800605
606
def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path').

    """load(f, map_location=None, pickle_module=pickle, **pickle_load_args)

    Loads an object saved with :func:`torch.save` from a file.

    :func:`torch.load` uses Python's unpickling facilities but treats storages,
    which underlie tensors, specially. They are first deserialized on the
    CPU and are then moved to the device they were saved from. If this fails
    (e.g. because the run time system doesn't have certain devices), an exception
    is raised. However, storages can be dynamically remapped to an alternative
    set of devices using the :attr:`map_location` argument.

    If :attr:`map_location` is a callable, it will be called once for each serialized
    storage with two arguments: storage and location. The storage argument
    will be the initial deserialization of the storage, residing on the CPU.
    Each serialized storage has a location tag associated with it which
    identifies the device it was saved from, and this tag is the second
    argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
    for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
    :attr:`map_location` should return either ``None`` or a storage. If
    :attr:`map_location` returns a storage, it will be used as the final deserialized
    object, already moved to the right device. Otherwise, :func:`torch.load` will
    fall back to the default behavior, as if :attr:`map_location` wasn't specified.

    If :attr:`map_location` is a :class:`torch.device` object or a string containing
    a device tag, it indicates the location where all tensors should be loaded.

    Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
    appearing in the file (keys), to ones that specify where to put the
    storages (values).

    User extensions can register their own location tags and tagging and
    deserialization methods using :func:`torch.serialization.register_package`.

    Args:
        f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
            or a string or os.PathLike object containing a file name
        map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
            locations
        pickle_module: module used for unpickling metadata and objects (has to
            match the :attr:`pickle_module` used to serialize file)
        pickle_load_args: (Python 3 only) optional keyword arguments passed over to
            :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
            :attr:`errors=...`.

    .. warning::
        :func:`torch.load()` uses ``pickle`` module implicitly, which is known to be insecure.
        It is possible to construct malicious pickle data which will execute arbitrary code
        during unpickling. Never load data that could have come from an untrusted
        source, or that could have been tampered with. **Only load data you trust**.

    .. note::
        When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
        will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')``
        and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.

    .. note::
        By default, we decode byte strings as ``utf-8``. This is to avoid a common error
        case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
        when loading files saved by Python 2 in Python 3. If this default
        is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
        these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
        to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
        as byte arrays which can be decoded later with ``byte_array.decode(...)``.

    Example:
        >>> torch.load('tensors.pt')
        # Load all tensors onto the CPU
        >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
        # Load all tensors onto the CPU, using a function
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
        # Load all tensors onto GPU 1
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
        # Load tensor from io.BytesIO object
        >>> with open('tensor.pt', 'rb') as f:
        ...     buffer = io.BytesIO(f.read())
        >>> torch.load(buffer)
        # Load a module with 'ascii' encoding for unpickling
        >>> torch.load('module.pt', encoding='ascii')
    """
    # Reject pickle_module implementations we can't work with (see
    # _check_dill_version) before touching the file at all.
    _check_dill_version(pickle_module)

    # Default to utf-8 decoding of byte strings saved by Python 2 — see the
    # note in the docstring above.  (Direct `in` test; `.keys()` is redundant.)
    if 'encoding' not in pickle_load_args:
        pickle_load_args['encoding'] = 'utf-8'

    with _open_file_like(f, 'rb') as opened_file:
        if _is_zipfile(opened_file):
            # The zipfile reader is going to advance the current file position.
            # If we want to actually tail call to torch.jit.load, we need to
            # reset back to the original position.
            orig_position = opened_file.tell()
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                if _is_torchscript_zip(opened_zipfile):
                    warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
                                  " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
                                  " silence this warning)", UserWarning)
                    opened_file.seek(orig_position)
                    return torch.jit.load(opened_file)
                return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
        # Non-zip files use the pre-1.6 serialization formats.
        return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
Adam Paszkee867baa2016-11-01 13:22:01 +0100714
715
Pearu Petersonb7fb2b82019-10-04 08:07:44 -0700716# Register pickling support for layout instances such as
717# torch.sparse_coo, etc
718def _get_layout(name):
Shen Li10224432021-08-12 11:39:31 -0700719 """Get layout extension object from its string representation.
720 """
721 cache = _get_layout.cache # type: ignore[attr-defined]
Pearu Petersonb7fb2b82019-10-04 08:07:44 -0700722 if not cache:
723 for v in torch.__dict__.values():
724 if isinstance(v, torch.layout):
725 cache[str(v)] = v
726 return cache[name]
727
Nikita Shulga591fffc2020-07-02 07:09:21 -0700728# There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087
Shen Li10224432021-08-12 11:39:31 -0700729_get_layout.cache = {} # type: ignore[attr-defined]
Pearu Petersonb7fb2b82019-10-04 08:07:44 -0700730copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
731
732
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    """Deserialize an object from the pre-zipfile serialization formats.

    Two on-disk layouts are handled:

    * the original tar container holding ``storages``, ``tensors`` and
      ``pickle`` members (see the nested ``legacy_load`` below), and
    * the later stream format: magic number, protocol version, system info,
      the pickled result, then the raw storage bytes appended at the end.

    Args:
        f: an opened, seekable file-like object positioned at the data start.
        map_location: same semantics as for :func:`torch.load`.
        pickle_module: module providing ``load`` and ``Unpickler``.
        pickle_load_args: forwarded to ``pickle_module.load`` / ``Unpickler``.
    """
    # Maps each serialized storage key to the storage built for it, so
    # storages shared between tensors are materialized only once.
    deserialized_objects: Dict[int, Any] = {}

    restore_location = _get_restore_location(map_location)

    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]

        def find_class(self, mod_name, name):
            # Legacy checkpoints reference concrete storage classes by name
            # (e.g. 'FloatStorage'); resolve those to a StorageType carrying
            # the dtype.  Unknown names fall through to normal resolution.
            if type(name) is str and 'Storage' in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            return super().find_class(mod_name, name)

    def _check_container_source(container_type, source_file, original_source):
        # Compare the source code recorded at save time with the class's
        # current source; emit a SourceChangeWarning when they differ.
        try:
            current_source = ''.join(get_source_lines_and_file(container_type)[0])
        except Exception:  # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + container_type.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
            return
        if original_source != current_source:
            if container_type.dump_patches:
                # Write a reverse patch so the user can restore the saved
                # version of the class with the standard `patch` tool.
                file_name = container_type.__name__ + '.patch'
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as f:
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            # A different patch file already exists; don't clobber it.
                            raise IOError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except IOError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
            warnings.warn(msg, SourceChangeWarning)

    def legacy_load(f):
        # Reader for the oldest format: a tar archive with 'storages',
        # 'tensors' and 'pickle' members.  Uses its own key->object map.
        deserialized_objects: Dict[int, Any] = {}

        def persistent_load(saved_id):
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:

            # NOTE: `f` is deliberately shadowed below by the extracted
            # member files read from the temporary directory.
            tar.extract('storages', path=tmpdir)
            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
                num_storages = pickle_module.load(f, **pickle_load_args)
                for i in range(num_storages):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, location, storage_type = args
                    dtype = storage_type.dtype
                    obj = cast(Storage, torch._UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
                    obj = restore_location(obj, location)
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with _TypedStorage
                    deserialized_objects[key] = torch.storage._TypedStorage(
                        wrap_storage=obj,
                        dtype=dtype)

                # Views are (target, root, offset, numel) tuples slicing into
                # an already-deserialized root storage.
                storage_views = pickle_module.load(f, **pickle_load_args)
                for target_cdata, root_cdata, offset, numel in storage_views:
                    root = deserialized_objects[root_cdata]
                    element_size = torch._utils._element_size(root.dtype)
                    offset_bytes = offset * element_size
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with _TypedStorage
                    deserialized_objects[target_cdata] = torch.storage._TypedStorage(
                        wrap_storage=root._storage[offset_bytes:offset_bytes + numel * element_size],
                        dtype=root.dtype)

            tar.extract('tensors', path=tmpdir)
            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
                num_tensors = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_tensors):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    ndim, = struct.unpack('<i', f.read(4))
                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                    f.read(4)
                    # `numel` here is the per-dimension size tuple (one int64
                    # per dim), passed as the size argument to set_() below.
                    numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    storage_offset, = struct.unpack('<q', f.read(8))
                    tensor = torch.tensor([], dtype=storage.dtype).set_(
                        storage._storage, storage_offset, numel, stride)
                    deserialized_objects[key] = tensor

            pickle_file = tar.extractfile('pickle')
            unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result

    # From here on we handle the newer (non-tar) stream format; start with a
    # fresh key->storage map.
    deserialized_objects = {}

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == 'storage':
            storage_type, root_key, location, numel, view_metadata = data
            location = _maybe_decode_ascii(location)
            dtype = storage_type.dtype

            nbytes = numel * torch._utils._element_size(dtype)

            if root_key not in deserialized_objects:
                # Storage contents are not in the pickle stream; allocate an
                # uninitialized storage now and fill it from the trailing
                # storage data at the end of this function.
                obj = cast(Storage, torch._UntypedStorage(nbytes))
                obj._torch_load_uninitialized = True
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with _TypedStorage
                deserialized_objects[root_key] = torch.storage._TypedStorage(
                    wrap_storage=restore_location(obj, location),
                    dtype=dtype)

            typed_storage = deserialized_objects[root_key]
            if view_metadata is not None:
                # A view aliases a slice of the root storage's bytes.
                view_key, offset, view_size = view_metadata
                offset_bytes = offset * torch._utils._element_size(dtype)
                view_size_bytes = view_size * torch._utils._element_size(dtype)
                if view_key not in deserialized_objects:
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with _TypedStorage
                    deserialized_objects[view_key] = torch.storage._TypedStorage(
                        wrap_storage=typed_storage._storage[offset_bytes:offset_bytes + view_size_bytes],
                        dtype=dtype)
                res = deserialized_objects[view_key]

            else:
                res = typed_storage
            return res
        else:
            raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # legacy_load requires that f has fileno()
        # only if offset is zero we can attempt the legacy tar file loader
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if _is_zipfile(f):
                # .zip is used for torch.jit.save and will throw an un-pickling error here
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
            # if not a tarfile, reset file offset and proceed
            f.seek(0)

    if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
        raise RuntimeError(
            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
            f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this "
            "functionality.")

    # Stream format header: magic number, protocol version, then system info
    # (read and discarded) precede the pickled payload.
    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError("Invalid protocol version: %s" % protocol_version)

    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = UnpicklerWrapper(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

    # Raw storage bytes follow the pickled payload; fill each storage that
    # persistent_load allocated above, in key order.
    offset = f.tell() if f_should_read_directly else None
    for key in deserialized_storage_keys:
        assert key in deserialized_objects
        typed_storage = deserialized_objects[key]
        typed_storage._storage._set_from_file(
            f, offset, f_should_read_directly,
            torch._utils._element_size(typed_storage.dtype))
        if offset is not None:
            offset = f.tell()

    torch._utils._validate_loaded_sparse_tensors()

    return result
David Riazatidca123e2019-11-19 10:14:44 -0800947
948
Nikita Shulga591fffc2020-07-02 07:09:21 -0700949def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
David Riazatidca123e2019-11-19 10:14:44 -0800950 # When using encoding='bytes' in Py3, some **internal** keys stored as
951 # strings in Py2 are loaded as bytes. This function decodes them with
952 # ascii encoding, one that Py3 uses by default.
953 #
954 # NOTE: This should only be used on internal keys (e.g., `typename` and
955 # `location` in `persistent_load` below!
956 if isinstance(bytes_str, bytes):
Shen Li10224432021-08-12 11:39:31 -0700957 return bytes_str.decode('ascii')
David Riazatidca123e2019-11-19 10:14:44 -0800958 return bytes_str
959
960
def _get_restore_location(map_location):
    """Translate a user-supplied ``map_location`` into a callable.

    The returned callable takes ``(storage, location)`` and produces the
    relocated storage.  Dispatch mirrors the documented ``map_location``
    forms: ``None`` (default behavior), dict (location-tag remapping),
    string or ``torch.device`` (force a single target device), or an
    arbitrary callable with ``None``-means-fallback semantics.
    """
    if map_location is None:
        return default_restore_location

    if isinstance(map_location, dict):
        def restore_location(storage, location):
            # Remap the serialized tag through the dict; unmapped tags pass
            # through unchanged.
            return default_restore_location(storage, map_location.get(location, location))
        return restore_location

    if isinstance(map_location, _string_classes):
        def restore_location(storage, location):
            # Every storage is sent to the single user-requested location.
            return default_restore_location(storage, map_location)
        return restore_location

    if isinstance(map_location, torch.device):
        def restore_location(storage, location):
            return default_restore_location(storage, str(map_location))
        return restore_location

    def restore_location(storage, location):
        # User callable: a None result means "fall back to the default".
        mapped = map_location(storage, location)
        return mapped if mapped is not None else default_restore_location(storage, location)
    return restore_location
981
class StorageType():
    """Stand-in for a legacy ``torch.XxxStorage`` class name found in a
    pickle stream; records only the dtype that the class name maps to.
    """

    def __init__(self, name):
        # e.g. 'FloatStorage' resolves to torch.float32; raises KeyError for
        # names that are not pickle storage types.
        self.dtype = _get_dtype_from_pickle_storage_type(name)

    def __str__(self):
        return 'StorageType(dtype={})'.format(self.dtype)
988
def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickle_load_args):
    """Deserialize an object from the zipfile-based checkpoint format.

    Args:
        zip_file: an opened archive reader exposing ``get_record``,
            ``get_storage_from_record`` and friends (torch's zip reader).
        map_location: same semantics as for :func:`torch.load`.
        pickle_module: module providing an ``Unpickler`` class.
        pickle_file: name of the archive record holding the pickled object.
        pickle_load_args: forwarded to ``pickle_module.Unpickler``.

    Returns:
        The unpickled object with all storages restored via ``map_location``.
    """
    restore_location = _get_restore_location(map_location)

    # Cache of storages already read from the archive, keyed by record key,
    # so a storage shared by several tensors is only loaded once.
    loaded_storages = {}

    def load_tensor(dtype, nbytes, key, location):
        # Read one storage record into the cache.  NOTE: the second parameter
        # is a *byte* count (callers multiply numel by the element size), so
        # it is named `nbytes` rather than the misleading `numel`.
        name = f'data/{key}'

        storage = zip_file.get_storage_from_record(name, nbytes, torch._UntypedStorage).storage()._untyped()
        # TODO: Once we decide to break serialization FC, we can
        # stop wrapping with _TypedStorage
        loaded_storages[key] = torch.storage._TypedStorage(
            wrap_storage=restore_location(storage, location),
            dtype=dtype)

    def persistent_load(saved_id):
        # Resolve a pickled persistent id of the form
        # ('storage', storage_type, key, location, numel) to a typed storage.
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        assert typename == 'storage', \
            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        storage_type, key, location, numel = data
        dtype = storage_type.dtype

        if key not in loaded_storages:
            nbytes = numel * torch._utils._element_size(dtype)
            load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))

        return loaded_storages[key]

    load_module_mapping: Dict[str, str] = {
        # See https://github.com/pytorch/pytorch/pull/51633
        'torch.tensor': 'torch._tensor'
    }

    # Need to subclass Unpickler instead of directly monkey-patching the find_class method
    # because it's marked readonly in pickle.
    # The type: ignore is because mypy can't statically determine the type of this class.
    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
        # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
        # Lets us override the imports that pickle uses when unpickling an object.
        # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
        def find_class(self, mod_name, name):
            # Storage class names (e.g. 'FloatStorage') resolve to a
            # StorageType carrying the dtype; unknown names fall through.
            if type(name) is str and 'Storage' in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            mod_name = load_module_mapping.get(mod_name, mod_name)
            return super().find_class(mod_name, name)

    # Load the data (which may in turn use `persistent_load` to load tensors)
    data_file = io.BytesIO(zip_file.get_record(pickle_file))

    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    # Validate any sparse tensors that were deserialized above.
    torch._utils._validate_loaded_sparse_tensors()

    return result
David Riazati8c6f0c02019-11-22 12:28:49 -08001051
1052
1053def _is_torchscript_zip(zip_file):
Shen Li10224432021-08-12 11:39:31 -07001054 return 'constants.pkl' in zip_file.get_all_records()