Add warning for weights_only (#129239)
Also changes the default value of `weights_only` to `None`, per the comment below (hence the `suppress-bc-linter` tag).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129239
Approved by: https://github.com/albanD
ghstack dependencies: #129244, #129251
diff --git a/test/test_nn.py b/test/test_nn.py
index 2553db0..e8947c9 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1802,26 +1802,35 @@
m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
with warnings.catch_warnings(record=True) as w:
m = pickle.loads(pickle.dumps(m))
- self.assertTrue(len(w) == 0)
+ # warning from torch.load call in _load_from_bytes
+ num_warnings = 2 if torch._dynamo.is_compiling() else 1
+ self.assertTrue(len(w) == num_warnings)
+ self.assertEqual(w[0].category, FutureWarning)
# Test whether loading from older checkpoints works without triggering warnings
m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
with warnings.catch_warnings(record=True) as w:
m = pickle.loads(pickle.dumps(m))
- self.assertTrue(len(w) == 0)
+ # warning from torch.load call in _load_from_bytes
+ self.assertTrue(len(w) == 1)
+ self.assertEqual(w[0].category, FutureWarning)
m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
with warnings.catch_warnings(record=True) as w:
m = pickle.loads(pickle.dumps(m))
- self.assertTrue(len(w) == 0)
+ # warning from torch.load call in _load_from_bytes
+ self.assertTrue(len(w) == 1)
+ self.assertEqual(w[0].category, FutureWarning)
# Test whether loading from older checkpoints works without triggering warnings
m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
with warnings.catch_warnings(record=True) as w:
m = pickle.loads(pickle.dumps(m))
- self.assertTrue(len(w) == 0)
+ # warning from torch.load call in _load_from_bytes
+ self.assertTrue(len(w) == 1)
+ self.assertEqual(w[0].category, FutureWarning)
def test_weight_norm_pickle(self):
m = torch.nn.utils.weight_norm(nn.Linear(5, 7))
diff --git a/test/test_serialization.py b/test/test_serialization.py
index a25f985..31136c6 100644
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -837,7 +837,6 @@
test(f_new, f_old)
self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")
-
class TestOldSerialization(TestCase, SerializationMixin):
# unique_key is necessary because on Python 2.7, if a warning passed to
# the warning module is the same, it is not raised again.
@@ -865,7 +864,8 @@
loaded = torch.load(checkpoint)
self.assertTrue(isinstance(loaded, module.Net))
if can_retrieve_source:
- self.assertEqual(len(w), 0)
+ self.assertEqual(len(w), 1)
+ self.assertEqual(w[0].category, FutureWarning)
# Replace the module with different source
fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
@@ -876,8 +876,8 @@
loaded = torch.load(checkpoint)
self.assertTrue(isinstance(loaded, module.Net))
if can_retrieve_source:
- self.assertEqual(len(w), 1)
- self.assertTrue(w[0].category, 'SourceChangeWarning')
+ self.assertEqual(len(w), 2)
+ self.assertTrue(w[1].category, 'SourceChangeWarning')
def test_serialization_container(self):
self._test_serialization_container('file', tempfile.NamedTemporaryFile)
diff --git a/torch/serialization.py b/torch/serialization.py
index d7c3fd1..91905b7 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -987,7 +987,7 @@
map_location: MAP_LOCATION = None,
pickle_module: Any = None,
*,
- weights_only: bool = False,
+ weights_only: Optional[bool] = None,
mmap: Optional[bool] = None,
**pickle_load_args: Any,
) -> Any:
@@ -1097,6 +1097,11 @@
" with `weights_only` please check the recommended steps in the following error message."
" WeightsUnpickler error: "
)
+ if weights_only is None:
+ weights_only, warn_weights_only = False, True
+ else:
+ warn_weights_only = False
+
# Add ability to force safe only weight loads via environment variable
if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in [
"1",
@@ -1113,6 +1118,20 @@
)
else:
if pickle_module is None:
+ if warn_weights_only:
+ warnings.warn(
+ "You are using `torch.load` with `weights_only=False` (the current default value), which uses "
+ "the default pickle module implicitly. It is possible to construct malicious pickle data "
+ "which will execute arbitrary code during unpickling (See "
+ "https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). "
+ "In a future release, the default value for `weights_only` will be flipped to `True`. This "
+ "limits the functions that could be executed during unpickling. Arbitrary objects will no "
+ "longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the "
+ "user via `torch.serialization.add_safe_globals`. We recommend you start setting "
+ "`weights_only=True` for any use case where you don't have full control of the loaded file. "
+ "Please open an issue on GitHub for any issues related to this experimental feature.",
+ FutureWarning,
+ )
pickle_module = pickle
# make flipping default BC-compatible