# Owner(s): ["module: nn"]
from copy import deepcopy
from tempfile import NamedTemporaryFile
import unittest

import torch
import torch.nn as nn
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import TestCase, \
    TEST_NUMPY, IS_WINDOWS, skipIfTorchDynamo, instantiate_parametrized_tests, \
    run_tests

if TEST_NUMPY:
    import numpy as np


class TestLoadStateDict(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True
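
    # Loading values that are not torch.Tensor (e.g. numpy arrays or nested
    # tuples of floats) should be rejected with a RuntimeError.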
    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    def test_load_state_dict_invalid(self):
        m = torch.nn.Linear(2, 2, bias=False)

        state_dict = {'weight': np.random.randn(2, 2)}
        with self.assertRaisesRegex(RuntimeError,
                                    "expected torch.Tensor or Tensor-like object from checkpoint but received"):
            m.load_state_dict(state_dict)

        state_dict = {'weight': ((1., 1.), (2., 2.))}
        with self.assertRaisesRegex(RuntimeError,
                                    "expected torch.Tensor or Tensor-like object from checkpoint but received"):
            m.load_state_dict(state_dict)
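
    # Passing a non-mapping (e.g. a str or an int) as the state_dict should
    # raise a TypeError.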
    def test_load_state_dict_type(self):
        m = nn.Module()

        with self.assertRaisesRegex(TypeError,
                                    "Expected state_dict to be dict-like, got"):
            m.load_state_dict("")
        with self.assertRaisesRegex(TypeError,
                                    "Expected state_dict to be dict-like, got"):
            m.load_state_dict(2)
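
    # End-to-end coverage of load_state_dict: strict and non-strict loading,
    # missing/unexpected keys reported via the returned IncompatibleKeys,
    # size mismatches, and loading a DDP checkpoint after stripping its
    # 'module.' prefix.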
    def test_load_state_dict(self):
        l = nn.Linear(5, 5)
        block = nn.Module()
        block.conv1 = nn.Conv2d(3, 3, 3, bias=True)
        block.conv2 = nn.Conv2d(3, 3, 3, bias=False)
        net = nn.Module()
        net.linear1 = l
        net.linear2 = l
        net.bn = nn.BatchNorm2d(2)
        net.block = block
        net.add_module('empty', None)
        conv1_bias_dtype = block.conv1.bias.dtype

        state_dict = net.state_dict()
        state_dict.update({
            'linear1.weight': torch.ones(5, 5),
            'block.conv1.bias': torch.arange(1, 4, dtype=conv1_bias_dtype),
            'bn.running_mean': torch.randn(2),
        })
        # Also test if a DDP state_dict can be loaded from a local model.
        ddp_state_dict = net.state_dict()
        ddp_state_dict.update({
            'module.linear1.weight': torch.ones(5, 5),
            'module.block.conv1.bias': torch.arange(1, 4, dtype=conv1_bias_dtype),
            'module.bn.running_mean': torch.randn(2),
        })
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(ddp_state_dict, 'module.')
        for sd in [state_dict, ddp_state_dict]:
            incompatible_keys = net.load_state_dict(sd)
            self.assertEqual(len(incompatible_keys.missing_keys), 0)
            self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
            self.assertNotIn('Incompatible', str(incompatible_keys))
            self.assertEqual(net.linear1.weight, sd['linear1.weight'])
            self.assertEqual(net.block.conv1.bias, sd['block.conv1.bias'])
            self.assertEqual(net.bn.running_mean, sd['bn.running_mean'])

        state_dict = net.state_dict()
        state_dict.update({'extra': torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 0)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn('extra', incompatible_keys.unexpected_keys)
        self.assertIn('Incompatible', str(incompatible_keys))

        state_dict = net.state_dict()
        state_dict.update({'extra.param': torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 0)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn('extra.param', incompatible_keys.unexpected_keys)

        state_dict = net.state_dict()
        del state_dict['linear1.weight']
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 1)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
        self.assertIn('linear1.weight', incompatible_keys.missing_keys)
        state_dict.update({'extra.param': torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 1)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn('linear1.weight', incompatible_keys.missing_keys)
        self.assertIn('extra.param', incompatible_keys.unexpected_keys)

        state_dict = net.state_dict()
        state_dict.update({'bn.running_mean': torch.rand(14, 4)})  # wrong size
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict, strict=False))

        state_dict = net.state_dict()
        old_state_dict = deepcopy(state_dict)
        state_dict = {
            'linear1.weight': torch.ones(5, 5),
            'block.conv1.bias': torch.arange(1, 4, dtype=conv1_bias_dtype),
            'bn.running_mean': torch.randn(2),
            'nonexistent_key': torch.rand(3)
        }
        net.load_state_dict(state_dict, strict=False)
        self.assertEqual(net.linear1.weight, state_dict['linear1.weight'])
        self.assertEqual(net.block.conv1.bias, state_dict['block.conv1.bias'])
        self.assertEqual(net.bn.running_mean, state_dict['bn.running_mean'])
        new_state_dict = net.state_dict()
        del old_state_dict['linear1.weight']
        del old_state_dict['block.conv1.bias']
        del old_state_dict['bn.running_mean']
        for k, v in old_state_dict.items():
            self.assertTrue(v.equal(new_state_dict[k]))

    def test_load_state_dict_BC(self):
        # BatchNormNd
        # Added num_batches_tracked buffer at version 2. For state dicts with
        # earlier versions or no version, it should provide a default value of 0.
        bn = nn.BatchNorm2d(3)
        state_dict = bn.state_dict()
        del state_dict['num_batches_tracked']
        state_dict._metadata['']['version'] = 1  # version 1
        bn.load_state_dict(state_dict)
        self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
        self.assertEqual(bn.num_batches_tracked.item(), 0)
        del state_dict._metadata['']['version']  # no version
        bn.load_state_dict(state_dict)
        self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
        self.assertEqual(bn.num_batches_tracked.item(), 0)
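
    # A load_state_dict pre-hook registered on a nested child with
    # with_module=True should receive that child and see a state_dict
    # containing only the child's own keys.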
    def test_load_state_dict_child(self):
        base_module = nn.Linear(1, 1)
        model = base_module
        for _ in range(3):
            model = nn.Sequential(*[deepcopy(model) for _ in range(10)])

        def hook_fn(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            module_state_dict = module.state_dict()
            self.assertEqual(len(module_state_dict.keys()), len(state_dict.keys()))

        model[0][0]._register_load_state_dict_pre_hook(hook_fn, with_module=True)
        model.load_state_dict(model.state_dict(), strict=True)
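
    # A module serialized before _state_dict_pre_hooks existed (simulated by
    # deleting the attribute) should still deserialize and support both
    # state_dict() and pre-hook registration.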
    @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
    def test_register_state_dict_pre_hook_backward_compat(self):
        called = False

        def my_state_dict_pre_hook(*args, **kwargs):
            nonlocal called
            called = True

        m = nn.Linear(1, 1)
        self.assertTrue(hasattr(m, '_state_dict_pre_hooks'))
        delattr(m, '_state_dict_pre_hooks')
        # Save and load, ensure we can still call state_dict
        # without running into issues.
        with NamedTemporaryFile() as f:
            # Note that torch.save / torch.load is not recommended
            # to save / load modules.
            torch.save(m, f.name)
            m = torch.load(f.name)

        # Ensure we can run state_dict without issues
        _ = m.state_dict()
        self.assertFalse(called)
        m.register_state_dict_pre_hook(my_state_dict_pre_hook)
        _ = m.state_dict()
        self.assertTrue(called)

    # FIXME: doesn't fail locally, maybe remove
    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
    def test_load_state_dict_ref_cycle(self):
        # load_state_dict shouldn't cause a reference cycle involving Tensors
        import gc

        m = torch.nn.LSTM(16, 16, bidirectional=True)
        gc.collect()
        m.load_state_dict(deepcopy(m).state_dict())
        refcycles = gc.collect()
        self.assertEqual(refcycles, 0)
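
    # Modules can customize what gets serialized by overriding
    # _save_to_state_dict / _load_from_state_dict; check that the overrides
    # compose with nesting and submodules.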
    def test_load_state_dict_custom(self):

        class CustomState(nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.ones(1))
                self.sub = torch.nn.Linear(5, 5)

            def _save_to_state_dict(self, destination, prefix, keep_vars):
                destination[prefix + "serialized"] = self.param.data + 1

            def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs):
                # skip some of the error handling
                self.param.data.copy_(state_dict[prefix + "serialized"] - 1)

        # use sequential to verify nesting
        m = nn.Sequential(CustomState())
        with torch.no_grad():
            m[0].param[0] = 10
            m[0].sub.weight[0, 0] = 555
        state_dict = m.state_dict()
        self.assertEqual(state_dict["0.serialized"].item(), 11)
        self.assertIn("0.sub.weight", state_dict)
        self.assertNotIn("0.param", state_dict)

        del m
        mm = nn.Sequential(CustomState())
        self.assertEqual(mm[0].param[0].item(), 1)
        mm.load_state_dict(state_dict)
        self.assertEqual(mm[0].param[0].item(), 10)
        self.assertEqual(mm[0].sub.weight[0, 0].item(), 555)
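
    # With assign=True, loading into a meta-device module should swap in the
    # checkpoint tensors themselves instead of copying into meta tensors.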
    def test_load_state_dict_assign_meta(self):

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        state_dict = net.state_dict(keep_vars=True)

        with torch.device('meta'):
            net_meta = MyModule()

        net_meta.load_state_dict(state_dict, assign=True)

        # Make sure parameters and persistent buffers were assigned
        net_meta_state_dict = net_meta.state_dict(keep_vars=True)
        for key in state_dict.keys():
            if isinstance(state_dict[key], torch.nn.Parameter):
                self.assertTrue(state_dict[key] is net_meta_state_dict[key])

        # Make sure that ordering of parameters and buffers is preserved
        net_named_parameters = net.named_parameters()
        net_named_buffers = net.named_buffers()
        net_meta_named_parameters = net_meta.named_parameters()
        net_meta_named_buffers = net_meta.named_buffers()

        for p1, p2 in zip(net_named_parameters, net_meta_named_parameters):
            n1, _ = p1
            n2, _ = p2
            self.assertEqual(n1, n2)

        for p1, p2 in zip(net_named_buffers, net_meta_named_buffers):
            n1, _ = p1
            n2, _ = p2
            self.assertEqual(n1, n2)

        # Make sure outputs are the same
        t = torch.randn(4, 3)
        out_net = net(t)
        out_net_meta = net_meta(t.clone())
        self.assertEqual(out_net, out_net_meta)
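
    # assign=True replaces the parameter objects, so an optimizer built from
    # the old parameters would go stale; verify training matches when the
    # optimizer is created only after the load.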
    def test_load_state_dict_assign_with_optimizer(self):

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        opt = torch.optim.Adam(net.parameters(), lr=1000)
        x = torch.randn(4, 3)
        num_iters = 3

        for i in range(num_iters):
            opt.zero_grad()
            out = net(x)
            out.sum().backward()
            opt.step()

        opt_state_dict = deepcopy(opt.state_dict())
        net_state_dict = deepcopy(net.state_dict())

        with torch.device('meta'):
            net_meta = MyModule()

        net_meta.load_state_dict(net_state_dict, assign=True)
        # must create optimizer only after loading state_dict when assign=True
        opt2 = torch.optim.Adam(net_meta.parameters(), lr=1000)
        opt2.load_state_dict(opt_state_dict)

        y = x.clone()
        for i in range(num_iters):
            opt.zero_grad()
            out = net(x)
            out.sum().backward()
            opt.step()

            opt2.zero_grad()
            out2 = net_meta(y)
            out2.sum().backward()
            opt2.step()

        self.assertEqual(opt.state_dict(), opt2.state_dict())
        self.assertEqual(net.state_dict(), net_meta.state_dict())

    def test_load_state_dict_assign_shape_stride(self):
        # The assigned tensor may differ from the initial tensor in every
        # property except shape
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        state_dict = net.state_dict()
        # loading should be ok if stride is different
        state_dict['fc1.weight'] = torch.randn(3, 5).transpose(0, 1)
        net2 = MyModule()
        net2.load_state_dict(state_dict, strict=False, assign=True)

        state_dict['fc1.weight'] = torch.randn(2, 4)
        with self.assertRaisesRegex(RuntimeError, "size mismatch for fc1.weight: copying a param with shape"):
            net2.load_state_dict(state_dict, strict=False, assign=True)
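
    # Copying a non-meta tensor from the checkpoint into a meta parameter
    # without assign=True should emit a UserWarning.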
    def test_load_state_dict_warn_assign(self):
        with torch.device('meta'):
            m = torch.nn.Linear(3, 5)
        state_dict = m.state_dict()
        state_dict['weight'] = torch.empty_like(state_dict['weight'], device='cpu')
        with self.assertWarnsRegex(UserWarning, "for weight: copying from a non-meta parameter in the checkpoint to a meta"):
            m.load_state_dict(state_dict)


instantiate_parametrized_tests(TestLoadStateDict)

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()