# Owner(s): ["module: nn"]
import re
import unittest
from copy import deepcopy
from itertools import product
from tempfile import NamedTemporaryFile
import torch
import torch.nn as nn
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_WINDOWS,
parametrize,
run_tests,
skipIfCrossRef,
skipIfTorchDynamo,
swap,
TEST_NUMPY,
TestCase,
)
from torch.utils._pytree import tree_map
if TEST_NUMPY:
import numpy as np
class TestLoadStateDict(NNTestCase):
_do_cuda_memory_leak_check = True
_do_cuda_non_default_stream = True
@unittest.skipIf(not TEST_NUMPY, "numpy not found")
@swap([True, False])
def test_load_state_dict_invalid(self):
m = torch.nn.Linear(2, 2, bias=False)
state_dict = {"weight": np.random.randn(2, 2)}
with self.assertRaisesRegex(
RuntimeError,
"expected torch.Tensor or Tensor-like object from checkpoint but received",
):
m.load_state_dict(state_dict)
state_dict = {"weight": ((1.0, 1.0), (2.0, 2.0))}
with self.assertRaisesRegex(
RuntimeError,
"expected torch.Tensor or Tensor-like object from checkpoint but received",
):
m.load_state_dict(state_dict)
@swap([True, False])
def test_load_state_dict_type(self):
m = nn.Module()
with self.assertRaisesRegex(
TypeError, "Expected state_dict to be dict-like, got"
):
m.load_state_dict("")
with self.assertRaisesRegex(
TypeError, "Expected state_dict to be dict-like, got"
):
m.load_state_dict(2)
@swap([True, False])
@skipIfTorchDynamo("dynamo installs weakrefs on some params")
def test_load_state_dict(self):
l = nn.Linear(5, 5)
block = nn.Module()
block.conv1 = nn.Conv2d(3, 3, 3, bias=True)
block.conv2 = nn.Conv2d(3, 3, 3, bias=False)
net = nn.Module()
net.linear1 = l
net.linear2 = l
net.bn = nn.BatchNorm2d(2)
net.block = block
net.add_module("empty", None)
conv1_bias_dtype = block.conv1.bias.dtype
state_dict = net.state_dict()
state_dict.update(
{
"linear1.weight": torch.ones(5, 5),
"block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype),
"bn.running_mean": torch.randn(2),
}
)
        # Also test that a DDP state_dict (keys prefixed with "module.") can be
        # loaded into a local model.
ddp_state_dict = net.state_dict()
ddp_state_dict.update(
{
"module.linear1.weight": torch.ones(5, 5),
"module.block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype),
"module.bn.running_mean": torch.randn(2),
}
)
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
ddp_state_dict, "module."
)
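        # consume_prefix_in_state_dict_if_present rewrites the keys in place, e.g.
        # "module.linear1.weight" -> "linear1.weight", so ddp_state_dict should now
        # load exactly like the plain local state_dict.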
for sd in [state_dict, ddp_state_dict]:
incompatible_keys = net.load_state_dict(sd)
self.assertEqual(len(incompatible_keys.missing_keys), 0)
self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
self.assertNotIn("Incompatible", str(incompatible_keys))
self.assertEqual(net.linear1.weight, sd["linear1.weight"])
self.assertEqual(net.block.conv1.bias, sd["block.conv1.bias"])
self.assertEqual(net.bn.running_mean, sd["bn.running_mean"])
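        # An unexpected key should fail a strict load and be reported under
        # strict=False.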
state_dict = net.state_dict()
state_dict.update({"extra": torch.ones(5)})
self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
incompatible_keys = net.load_state_dict(state_dict, strict=False)
self.assertEqual(len(incompatible_keys.missing_keys), 0)
self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
self.assertIn("extra", incompatible_keys.unexpected_keys)
self.assertIn("Incompatible", str(incompatible_keys))
state_dict = net.state_dict()
state_dict.update({"extra.param": torch.ones(5)})
self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
incompatible_keys = net.load_state_dict(state_dict, strict=False)
self.assertEqual(len(incompatible_keys.missing_keys), 0)
self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
self.assertIn("extra.param", incompatible_keys.unexpected_keys)
state_dict = net.state_dict()
del state_dict["linear1.weight"]
self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
incompatible_keys = net.load_state_dict(state_dict, strict=False)
self.assertEqual(len(incompatible_keys.missing_keys), 1)
self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
self.assertIn("linear1.weight", incompatible_keys.missing_keys)
state_dict.update({"extra.param": torch.ones(5)})
self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
incompatible_keys = net.load_state_dict(state_dict, strict=False)
self.assertEqual(len(incompatible_keys.missing_keys), 1)
self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
self.assertIn("linear1.weight", incompatible_keys.missing_keys)
self.assertIn("extra.param", incompatible_keys.unexpected_keys)
state_dict = net.state_dict()
state_dict.update({"bn.running_mean": torch.rand(14, 4)}) # wrong size
self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
self.assertRaises(
RuntimeError, lambda: net.load_state_dict(state_dict, strict=False)
)
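        # strict=False ignores keys the module does not have but still copies the
        # matching entries; everything else should be left untouched.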
state_dict = net.state_dict()
old_state_dict = deepcopy(state_dict)
state_dict = {
"linear1.weight": torch.ones(5, 5),
"block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype),
"bn.running_mean": torch.randn(2),
"nonexistent_key": torch.rand(3),
}
net.load_state_dict(state_dict, strict=False)
self.assertEqual(net.linear1.weight, state_dict["linear1.weight"])
self.assertEqual(net.block.conv1.bias, state_dict["block.conv1.bias"])
self.assertEqual(net.bn.running_mean, state_dict["bn.running_mean"])
new_state_dict = net.state_dict()
del old_state_dict["linear1.weight"]
del old_state_dict["block.conv1.bias"]
del old_state_dict["bn.running_mean"]
        for k, v in old_state_dict.items():
self.assertTrue(v.equal(new_state_dict[k]))
@swap([True, False])
def test_load_state_dict_BC(self):
        # BatchNormNd
        # The num_batches_tracked buffer was added in version 2. For state dicts
        # from earlier versions, or with no version at all, loading should fill in
        # a default value of 0.
bn = nn.BatchNorm2d(3)
state_dict = bn.state_dict()
del state_dict["num_batches_tracked"]
state_dict._metadata[""]["version"] = 1 # version 1
bn.load_state_dict(state_dict)
self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
self.assertEqual(bn.num_batches_tracked.item(), 0)
del state_dict._metadata[""]["version"] # no version
bn.load_state_dict(state_dict)
self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
self.assertEqual(bn.num_batches_tracked.item(), 0)
@swap([True, False])
def test_load_state_dict_child(self):
base_module = nn.Linear(1, 1)
model = base_module
for _ in range(3):
model = nn.Sequential(*[deepcopy(model) for _ in range(10)])
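        # The pre-hook is registered on a nested child with with_module=True, so the
        # state_dict it receives should have the same number of keys as the child's
        # own state_dict.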
def hook_fn(
module,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
module_state_dict = module.state_dict()
self.assertEqual(len(module_state_dict.keys()), len(state_dict.keys()))
model[0][0]._register_load_state_dict_pre_hook(hook_fn, with_module=True)
model.load_state_dict(model.state_dict(), strict=True)
@unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
@swap([True, False])
def test_register_state_dict_pre_hook_backward_compat(self):
called = False
def my_state_dict_pre_hook(*args, **kwargs):
nonlocal called
called = True
m = nn.Linear(1, 1)
self.assertTrue(hasattr(m, "_state_dict_pre_hooks"))
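        # Delete the attribute to simulate a module pickled by an older torch that
        # predates _state_dict_pre_hooks.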
delattr(m, "_state_dict_pre_hooks")
# Save and load, ensure we can still call state_dict
# without running into issues.
with NamedTemporaryFile() as f:
# Note that torch.save / torch.load is not recommended
# to save / load modules.
torch.save(m, f.name)
m = torch.load(f.name)
# Ensure we can run state_dict without issues
_ = m.state_dict()
self.assertFalse(called)
m.register_state_dict_pre_hook(my_state_dict_pre_hook)
_ = m.state_dict()
self.assertTrue(called)
    # Swapping fails because LSTM installs weak references on the parameters.
@swap([False])
@skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
def test_load_state_dict_ref_cycle(self):
# load_state_dict shouldn't cause a reference cycle involving Tensors
import gc
m = torch.nn.LSTM(16, 16, bidirectional=True)
gc.collect()
m.load_state_dict(deepcopy(m).state_dict())
refcycles = gc.collect()
self.assertEqual(refcycles, 0)
@swap([True, False])
def test_load_state_dict_custom(self):
class CustomState(nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.ones(1))
self.sub = torch.nn.Linear(5, 5)
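            # Custom (de)serialization: the parameter is stored under a "serialized"
            # key with a +1 offset, and _load_from_state_dict undoes the offset.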
def _save_to_state_dict(self, destination, prefix, keep_vars):
destination[prefix + "serialized"] = self.param.data + 1
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
# skip some of the error handling
self.param.data.copy_(state_dict[prefix + "serialized"] - 1)
# use sequential to verify nesting
m = nn.Sequential(CustomState())
with torch.no_grad():
m[0].param[0] = 10
m[0].sub.weight[0, 0] = 555
state_dict = m.state_dict()
self.assertEqual(state_dict["0.serialized"].item(), 11)
self.assertIn("0.sub.weight", state_dict)
self.assertNotIn("0.param", state_dict)
del m
mm = nn.Sequential(CustomState())
self.assertEqual(mm[0].param[0].item(), 1)
mm.load_state_dict(state_dict)
self.assertEqual(mm[0].param[0].item(), 10)
self.assertEqual(mm[0].sub.weight[0, 0].item(), 555)
@swap([True, False])
@parametrize("keep_vars", [True, False])
def test_load_state_dict_assign_meta(self, keep_vars):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(3, 5)
self.bn = nn.BatchNorm1d(5)
self.x = nn.Parameter(torch.rand(5), requires_grad=False)
def forward(self, input):
return self.x + self.bn(self.fc1(input))
swap = torch.__future__.get_swap_module_params_on_conversion()
net = MyModule()
state_dict = net.state_dict(keep_vars=keep_vars)
for v in state_dict.values():
v.requires_grad_(False)
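        # Construct the module on the meta device and load with assign=True so the
        # real tensors from state_dict replace the meta tensors outright.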
with torch.device("meta"):
net_meta = MyModule()
net_meta_state_dict_old = net_meta.state_dict(keep_vars=True)
net_meta.load_state_dict(state_dict, assign=True)
# Make sure parameters and persistent buffers were assigned
net_meta_state_dict = net_meta.state_dict(keep_vars=True)
for key in state_dict.keys():
if key in net_meta._parameters:
if keep_vars and not swap:
# state_dict[key] is an nn.Parameter
self.assertTrue(state_dict[key] is net_meta_state_dict[key])
else:
if swap:
self.assertTrue(
net_meta_state_dict[key] is net_meta_state_dict_old[key]
)
else:
# state_dict[key] is not an nn.Parameter so it will be detached when wrapping with a Parameter
self.assertTrue(
net_meta_state_dict[key] is not net_meta_state_dict_old[key]
)
self.assertEqual(
net_meta_state_dict_old[key].requires_grad,
net_meta_state_dict[key].requires_grad,
)
self.assertEqual(
net_meta_state_dict_old[key].requires_grad,
net_meta_state_dict[key].requires_grad,
)
self.assertEqual(state_dict[key], net_meta_state_dict[key])
elif (
key in net_meta._buffers
and key not in net_meta._non_persistent_buffers_set
):
self.assertTrue(state_dict[key] is net_meta_state_dict[key])
self.assertEqual(state_dict[key], net_meta_state_dict[key])
# Make sure that ordering of parameters and buffers is preserved
net_named_parameters = net.named_parameters()
net_named_buffers = net.named_buffers()
net_meta_named_parameters = net_meta.named_parameters()
net_meta_named_buffers = net_meta.named_buffers()
for (n1, _), (n2, _) in zip(net_named_parameters, net_meta_named_parameters):
self.assertEqual(n1, n2)
for (n1, _), (n2, _) in zip(net_named_buffers, net_meta_named_buffers):
self.assertEqual(n1, n2)
# Make sure outputs are the same
t = torch.randn(4, 3)
out_net = net(t)
out_net_meta = net_meta(t.clone())
self.assertEqual(out_net, out_net_meta)
@swap([True, False])
def test_load_state_dict_assign_with_optimizer(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(3, 5)
self.bn = nn.BatchNorm1d(5)
def forward(self, input):
return self.bn(self.fc1(input))
net = MyModule()
opt = torch.optim.Adam(net.parameters(), lr=1000)
x = torch.randn(4, 3)
num_iters = 3
for i in range(num_iters):
opt.zero_grad()
out = net(x)
out.sum().backward()
opt.step()
opt_state_dict = deepcopy(opt.state_dict())
net_state_dict = deepcopy(net.state_dict())
with torch.device("meta"):
net_meta = MyModule()
net_meta.load_state_dict(net_state_dict, assign=True)
# must create optimizer only after loading state_dict when assign=True
opt2 = torch.optim.Adam(net_meta.parameters(), lr=1000)
opt2.load_state_dict(opt_state_dict)
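        # From here on, (net, opt) and (net_meta, opt2) should evolve identically.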
y = x.clone()
for i in range(num_iters):
opt.zero_grad()
out = net(x)
out.sum().backward()
opt.step()
opt2.zero_grad()
out2 = net_meta(y)
out2.sum().backward()
opt2.step()
self.assertEqual(opt.state_dict(), opt2.state_dict())
self.assertEqual(net.state_dict(), net_meta.state_dict())
@swap([True, False])
def test_load_state_dict_assign_shape_stride(self):
        # With assign=True, the assigned tensor is allowed to differ from the
        # initial tensor in every property except shape.
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(3, 5)
self.bn = nn.BatchNorm1d(5)
def forward(self, input):
return self.bn(self.fc1(input))
net = MyModule()
state_dict = net.state_dict()
# loading should be ok if stride is different
state_dict["fc1.weight"] = torch.randn(3, 5).transpose(0, 1)
net2 = MyModule()
net2.load_state_dict(state_dict, strict=False, assign=True)
state_dict["fc1.weight"] = torch.randn(2, 4)
with self.assertRaisesRegex(
RuntimeError, "size mismatch for fc1.weight: copying a param with shape"
):
net2.load_state_dict(state_dict, strict=False, assign=True)
@swap([True, False])
def test_load_state_dict_warn_assign(self):
with torch.device("meta"):
m = torch.nn.Linear(3, 5)
state_dict = m.state_dict()
state_dict["weight"] = torch.empty_like(state_dict["weight"], device="cpu")
with self.assertWarnsRegex(
UserWarning,
"for weight: copying from a non-meta parameter in the checkpoint to a meta",
):
m.load_state_dict(state_dict)
@swap([True, False])
def test_load_state_dict_with_unexpected_key(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(5, 10)
m = MyModule()
# Unexpected key & strict = True
with self.assertRaisesRegex(RuntimeError, "Unexpected key"):
state_dict = m.state_dict()
state_dict["fc1.bad_suffix"] = torch.randn(5, 10)
m.load_state_dict(state_dict)
# Unexpected key & strict = False
state_dict = m.load_state_dict(state_dict, strict=False)
self.assertIn("fc1.bad_suffix", state_dict.unexpected_keys)
# Unexpected key whose prefix matches a valid key & strict = True
with self.assertRaisesRegex(RuntimeError, "Unexpected key"):
state_dict = m.state_dict()
state_dict["fc1.weight.bad_suffix"] = torch.randn(5, 10)
m.load_state_dict(state_dict)
# Unexpected key whose prefix matches a valid key & strict = False
state_dict = m.load_state_dict(state_dict, strict=False)
self.assertIn("fc1.weight.bad_suffix", state_dict.unexpected_keys)
def load_torch_function_handler(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
def module_load(dest, src, assign=False):
if isinstance(dest, cls):
if assign:
return src.detach()
else:
if type(src) is torch.Tensor:
return cls(src)
elif type(src) is cls:
return src.detach()
else:
if isinstance(src, MyWrapperLoadTensor):
return cls(src._data)
return cls(src)
else:
assert isinstance(
src, cls
), f"Expected isinstance(src, {cls}) but got {type(src)}"
assert (
type(dest) == torch.Tensor
or type(dest) == torch.nn.Parameter
or issubclass(cls, type(dest))
)
if assign:
return src.detach()
else:
if isinstance(src, MyWrapperLoadTensor):
if type(dest) not in {torch.Tensor, torch.nn.Parameter}:
return type(dest)(src._data)
else:
return src._data.detach()
else:
return torch.Tensor(src)
if func is torch.Tensor.module_load:
return module_load(*args, **kwargs)
else:
with torch._C.DisableTorchFunctionSubclass():
# detach must return instance of same subclass for nn.Parameter()
if func == torch.Tensor.detach:
ret = func(*args, **kwargs)
if not isinstance(ret, cls):
return cls(ret)
return ret
return func(*args, **kwargs)
class MyLoadTensor(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
return load_torch_function_handler(cls, func, types, args, kwargs)
# We use MyLoadTensor2 to exercise (plain tensor subclass, wrapper tensor subclass)
# pairs where neither subclass inherits from the other.
class MyLoadTensor2(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
return load_torch_function_handler(cls, func, types, args, kwargs)
class MyBrokenLoadTensor(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if func is torch.Tensor.module_load:
# wrong as this doesn't detach!
return args[1]
else:
with torch._C.DisableTorchFunctionSubclass():
# detach must return instance of same subclass for nn.Parameter()
if func == torch.Tensor.detach:
return cls(func(*args, **kwargs))
return func(*args, **kwargs)
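# A wrapper tensor subclass: the real data lives in self._data, and
# __torch_dispatch__ unwraps it before (and rewraps it after) every op.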
class MyWrapperLoadTensor(MyLoadTensor):
@staticmethod
def __new__(cls, data: torch.Tensor):
t = torch.Tensor._make_wrapper_subclass(
cls,
data.size(),
dtype=data.dtype,
layout=data.layout,
device=data.device,
requires_grad=data.requires_grad,
strides=data.stride(),
storage_offset=data.storage_offset(),
)
return t
def __init__(self, data: torch.Tensor):
self._data = data
def __repr__(self):
return f"MyWrapperLoadTensor({self._data.__repr__()})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(t):
return t._data if isinstance(t, MyWrapperLoadTensor) else t
def wrap(t):
return MyWrapperLoadTensor(t) if isinstance(t, torch.Tensor) else t
kwargs = {} if kwargs is None else kwargs
out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
return tree_map(wrap, out)
class TestLoadStateDictSwap(TestCase):
@skipIfCrossRef
@skipIfTorchDynamo("Can't swap with dynamo as dynamo installs weakrefs")
@swap([True])
@parametrize("assign", [True, False])
def test_swap_subclass(self, assign):
def _create_model(subclass=None):
m = torch.nn.Linear(2, 3, bias=False)
m.register_buffer("buf", torch.randn(2, 3))
if subclass is not None:
m.weight = torch.nn.Parameter(subclass(m.weight))
m.buf = subclass(m.buf)
return m
def _test(m_subclass=None, sd_subclass=None):
m = _create_model(m_subclass)
sd = _create_model(sd_subclass).state_dict()
m.load_state_dict(sd, assign=assign)
self.assertEqual(m.weight, sd["weight"])
self.assertEqual(m.buf, sd["buf"])
self.assertTrue(isinstance(m.weight, torch.nn.Parameter))
self.assertTrue(not isinstance(m.buf, torch.nn.Parameter))
weight_type, buf_type = (torch.nn.Parameter, torch.Tensor)
if assign:
if sd_subclass is not None:
weight_type, buf_type = (sd_subclass, sd_subclass)
else:
if m_subclass is not None:
weight_type, buf_type = (m_subclass, m_subclass)
self.assertTrue(type(m.weight) is weight_type)
self.assertTrue(type(m.buf) is buf_type)
# (MyLoadTensor, MyWrapperLoadTensor) tests the behavior of (superclass, subclass)
subclasses = [None, MyLoadTensor, MyLoadTensor2, MyWrapperLoadTensor]
for m_s, sd_s in product(subclasses, subclasses):
_test(m_s, sd_s)
# MyBrokenLoadTensor should error since its module_load doesn't call .detach()
with self.assertRaisesRegex(
RuntimeError, re.escape("Error(s) in loading state_dict for Linear:")
):
_test(None, MyBrokenLoadTensor)
instantiate_parametrized_tests(TestLoadStateDict)
instantiate_parametrized_tests(TestLoadStateDictSwap)
if __name__ == "__main__":
TestCase._default_dtype_check_enabled = True
run_tests()