# Owner(s): ["module: nn"]
from copy import deepcopy
from itertools import product
import re
from tempfile import NamedTemporaryFile
import unittest

import torch
import torch.nn as nn
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import TestCase, \
    TEST_NUMPY, IS_WINDOWS, skipIfTorchDynamo, instantiate_parametrized_tests, \
    run_tests, skipIfCrossRef, swap
from torch.utils._pytree import tree_map

if TEST_NUMPY:
    import numpy as np

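# Note: the `swap` decorator from common_utils parametrizes each test over the
# tensor-swapping load path (presumably toggling
# torch.__future__.set_swap_module_params_on_conversion) versus the default
# in-place copy path, so most tests below run under both behaviors.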

class TestLoadStateDict(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    @swap([True, False])
    def test_load_state_dict_invalid(self):
        m = torch.nn.Linear(2, 2, bias=False)

        state_dict = {'weight': np.random.randn(2, 2)}
        with self.assertRaisesRegex(RuntimeError,
                                    "expected torch.Tensor or Tensor-like object from checkpoint but received"):
            m.load_state_dict(state_dict)

        state_dict = {'weight': ((1., 1.), (2., 2.))}
        with self.assertRaisesRegex(RuntimeError,
                                    "expected torch.Tensor or Tensor-like object from checkpoint but received"):
            m.load_state_dict(state_dict)

    @swap([True, False])
    def test_load_state_dict_type(self):
        m = nn.Module()

        with self.assertRaisesRegex(TypeError,
                                    "Expected state_dict to be dict-like, got"):
            m.load_state_dict("")
        with self.assertRaisesRegex(TypeError,
                                    "Expected state_dict to be dict-like, got"):
            m.load_state_dict(2)

    @swap([True, False])
    @skipIfTorchDynamo("dynamo installs weakrefs on some params")
    def test_load_state_dict(self):
        l = nn.Linear(5, 5)
        block = nn.Module()
        block.conv1 = nn.Conv2d(3, 3, 3, bias=True)
        block.conv2 = nn.Conv2d(3, 3, 3, bias=False)
        net = nn.Module()
        net.linear1 = l
        net.linear2 = l
        net.bn = nn.BatchNorm2d(2)
        net.block = block
        net.add_module('empty', None)
        conv1_bias_dtype = block.conv1.bias.dtype

        state_dict = net.state_dict()
        state_dict.update({
            'linear1.weight': torch.ones(5, 5),
            'block.conv1.bias': torch.arange(1, 4, dtype=conv1_bias_dtype),
            'bn.running_mean': torch.randn(2),
        })
        # Also test if a DDP state_dict can be loaded from a local model.
        ddp_state_dict = net.state_dict()
        ddp_state_dict.update({
            'module.linear1.weight': torch.ones(5, 5),
            'module.block.conv1.bias': torch.arange(1, 4, dtype=conv1_bias_dtype),
            'module.bn.running_mean': torch.randn(2),
        })
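        # consume_prefix_in_state_dict_if_present strips the leading 'module.'
        # prefix (added by DDP when it wraps a model) from the keys in place, so
        # the result should load cleanly into the unwrapped module.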
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(ddp_state_dict, 'module.')
        for sd in [state_dict, ddp_state_dict]:
            incompatible_keys = net.load_state_dict(sd)
            self.assertEqual(len(incompatible_keys.missing_keys), 0)
            self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
            self.assertNotIn('Incompatible', str(incompatible_keys))
            self.assertEqual(net.linear1.weight, sd['linear1.weight'])
            self.assertEqual(net.block.conv1.bias, sd['block.conv1.bias'])
            self.assertEqual(net.bn.running_mean, sd['bn.running_mean'])

        state_dict = net.state_dict()
        state_dict.update({'extra': torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 0)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn('extra', incompatible_keys.unexpected_keys)
        self.assertIn('Incompatible', str(incompatible_keys))

        state_dict = net.state_dict()
        state_dict.update({'extra.param': torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 0)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn('extra.param', incompatible_keys.unexpected_keys)

        state_dict = net.state_dict()
        del state_dict['linear1.weight']
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 1)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
        self.assertIn('linear1.weight', incompatible_keys.missing_keys)
        state_dict.update({'extra.param': torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 1)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn('linear1.weight', incompatible_keys.missing_keys)
        self.assertIn('extra.param', incompatible_keys.unexpected_keys)

        state_dict = net.state_dict()
        state_dict.update({'bn.running_mean': torch.rand(14, 4)})  # wrong size
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict, strict=False))

        state_dict = net.state_dict()
        old_state_dict = deepcopy(state_dict)
        state_dict = {
            'linear1.weight': torch.ones(5, 5),
            'block.conv1.bias': torch.arange(1, 4, dtype=conv1_bias_dtype),
            'bn.running_mean': torch.randn(2),
            'nonexistent_key': torch.rand(3)
        }
        net.load_state_dict(state_dict, strict=False)
        self.assertEqual(net.linear1.weight, state_dict['linear1.weight'])
        self.assertEqual(net.block.conv1.bias, state_dict['block.conv1.bias'])
        self.assertEqual(net.bn.running_mean, state_dict['bn.running_mean'])
        new_state_dict = net.state_dict()
        del old_state_dict['linear1.weight']
        del old_state_dict['block.conv1.bias']
        del old_state_dict['bn.running_mean']
        for k, v in old_state_dict.items():
            self.assertTrue(v.equal(new_state_dict[k]))

    @swap([True, False])
    def test_load_state_dict_BC(self):
        # BatchNormNd
        # The num_batches_tracked buffer was added in version 2. For state dicts
        # with earlier versions or no version at all, it should be given a
        # default value of 0.
        bn = nn.BatchNorm2d(3)
        state_dict = bn.state_dict()
        del state_dict['num_batches_tracked']
        state_dict._metadata['']['version'] = 1  # version 1
        bn.load_state_dict(state_dict)
        self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
        self.assertEqual(bn.num_batches_tracked.item(), 0)
        del state_dict._metadata['']['version']  # no version
        bn.load_state_dict(state_dict)
        self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
        self.assertEqual(bn.num_batches_tracked.item(), 0)

    @swap([True, False])
    def test_load_state_dict_child(self):
        base_module = nn.Linear(1, 1)
        model = base_module
        for _ in range(3):
            model = nn.Sequential(*[deepcopy(model) for _ in range(10)])

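        # The pre-hook registered below with with_module=True is installed on a
        # nested submodule; the assertion checks that the hook sees only that
        # submodule's slice of the state_dict (its key count matches the
        # submodule's own state_dict).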
        def hook_fn(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            module_state_dict = module.state_dict()
            self.assertEqual(len(module_state_dict.keys()), len(state_dict.keys()))

        model[0][0]._register_load_state_dict_pre_hook(hook_fn, with_module=True)
        model.load_state_dict(model.state_dict(), strict=True)

    @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
    @swap([True, False])
    def test_register_state_dict_pre_hook_backward_compat(self):
        called = False

        def my_state_dict_pre_hook(*args, **kwargs):
            nonlocal called
            called = True

        m = nn.Linear(1, 1)
        self.assertTrue(hasattr(m, '_state_dict_pre_hooks'))
        delattr(m, '_state_dict_pre_hooks')
        # Save and load, ensure we can still call state_dict
        # without running into issues.
        with NamedTemporaryFile() as f:
            # Note that torch.save / torch.load is not recommended
            # to save / load modules.
            torch.save(m, f.name)
            m = torch.load(f.name)

        # Ensure we can run state_dict without issues
        _ = m.state_dict()
        self.assertFalse(called)
        m.register_state_dict_pre_hook(my_state_dict_pre_hook)
        _ = m.state_dict()
        self.assertTrue(called)

    # fails swapping as LSTM installs weak references on the parameters
    @swap([False])
    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
    def test_load_state_dict_ref_cycle(self):
        # load_state_dict shouldn't cause a reference cycle involving Tensors
        import gc

        m = torch.nn.LSTM(16, 16, bidirectional=True)

        gc.collect()
        m.load_state_dict(deepcopy(m).state_dict())
        refcycles = gc.collect()

        self.assertEqual(refcycles, 0)

    @swap([True, False])
    def test_load_state_dict_custom(self):

        class CustomState(nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.ones(1))
                self.sub = torch.nn.Linear(5, 5)

            def _save_to_state_dict(self, destination, prefix, keep_vars):
                destination[prefix + "serialized"] = self.param.data + 1

            def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs):
                # skip some of the error handling
                self.param.data.copy_(state_dict[prefix + "serialized"] - 1)

        # use sequential to verify nesting
        m = nn.Sequential(CustomState())
        with torch.no_grad():
            m[0].param[0] = 10
            m[0].sub.weight[0, 0] = 555
        state_dict = m.state_dict()
        self.assertEqual(state_dict["0.serialized"].item(), 11)
        self.assertIn("0.sub.weight", state_dict)
        self.assertNotIn("0.param", state_dict)
        del m
        mm = nn.Sequential(CustomState())
        self.assertEqual(mm[0].param[0].item(), 1)
        mm.load_state_dict(state_dict)
        self.assertEqual(mm[0].param[0].item(), 10)
        self.assertEqual(mm[0].sub.weight[0, 0].item(), 555)

    @swap([True, False])
    def test_load_state_dict_assign_meta(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
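        # keep_vars=True makes state_dict() return the live Parameters/buffers
        # rather than detached copies, so the identity check below can verify
        # that assign=True put the checkpoint tensors themselves into net_meta.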
        state_dict = net.state_dict(keep_vars=True)

        with torch.device('meta'):
            net_meta = MyModule()

        net_meta.load_state_dict(state_dict, assign=True)

        # Make sure parameters and persistent buffers were assigned
        net_meta_state_dict = net_meta.state_dict(keep_vars=True)
        for key in state_dict.keys():
            if isinstance(state_dict[key], torch.nn.Parameter):
                self.assertTrue(state_dict[key] is net_meta_state_dict[key])

        # Make sure that ordering of parameters and buffers is preserved
        net_named_parameters = net.named_parameters()
        net_named_buffers = net.named_buffers()
        net_meta_named_parameters = net_meta.named_parameters()
        net_meta_named_buffers = net_meta.named_buffers()

        for p1, p2 in zip(net_named_parameters, net_meta_named_parameters):
            n1, _ = p1
            n2, _ = p2
            self.assertEqual(n1, n2)

        for p1, p2 in zip(net_named_buffers, net_meta_named_buffers):
            n1, _ = p1
            n2, _ = p2
            self.assertEqual(n1, n2)

        # Make sure outputs are the same
        t = torch.randn(4, 3)
        out_net = net(t)
        out_net_meta = net_meta(t.clone())

        self.assertEqual(out_net, out_net_meta)

    @swap([True, False])
    def test_load_state_dict_assign_with_optimizer(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        opt = torch.optim.Adam(net.parameters(), lr=1000)
        x = torch.randn(4, 3)
        num_iters = 3

        for i in range(num_iters):
            opt.zero_grad()
            out = net(x)
            out.sum().backward()
            opt.step()

        opt_state_dict = deepcopy(opt.state_dict())
        net_state_dict = deepcopy(net.state_dict())

        with torch.device('meta'):
            net_meta = MyModule()

        net_meta.load_state_dict(net_state_dict, assign=True)
        # must create optimizer only after loading state_dict when assign=True
        opt2 = torch.optim.Adam(net_meta.parameters(), lr=1000)
        opt2.load_state_dict(opt_state_dict)

        y = x.clone()
        for i in range(num_iters):
            opt.zero_grad()
            out = net(x)
            out.sum().backward()
            opt.step()

            opt2.zero_grad()
            out2 = net_meta(y)
            out2.sum().backward()
            opt2.step()

        self.assertEqual(opt.state_dict(), opt2.state_dict())
        self.assertEqual(net.state_dict(), net_meta.state_dict())

    @swap([True, False])
    def test_load_state_dict_assign_shape_stride(self):
        # The assigned tensor is allowed to have different properties (e.g. a
        # different stride) from the initial tensor, but its shape must match.
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        state_dict = net.state_dict()
        # loading should be ok if stride is different
        state_dict['fc1.weight'] = torch.randn(3, 5).transpose(0, 1)
        net2 = MyModule()
        net2.load_state_dict(state_dict, strict=False, assign=True)

        state_dict['fc1.weight'] = torch.randn(2, 4)
        with self.assertRaisesRegex(RuntimeError, "size mismatch for fc1.weight: copying a param with shape"):
            net2.load_state_dict(state_dict, strict=False, assign=True)

    @swap([True, False])
    def test_load_state_dict_warn_assign(self):
        with torch.device('meta'):
            m = torch.nn.Linear(3, 5)
        state_dict = m.state_dict()
        state_dict['weight'] = torch.empty_like(state_dict['weight'], device='cpu')
        with self.assertWarnsRegex(UserWarning, "for weight: copying from a non-meta parameter in the checkpoint to a meta"):
            m.load_state_dict(state_dict)


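# The tensor subclasses below exercise the __torch_function__ hook for
# torch.Tensor.module_load, which load_state_dict uses (on the tensor-swapping
# path) to decide what tensor ends up in the module. A handler is expected to
# return a detached tensor so it can be re-wrapped in nn.Parameter;
# MyBrokenLoadTensor deliberately gets this wrong.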
def load_torch_function_handler(cls, func, types, args=(), kwargs=None):
    kwargs = {} if kwargs is None else kwargs

    def module_load(dest, src):
        # always convert src to cls
        if isinstance(dest, cls):
            if type(src) is torch.Tensor:
                return cls(src)
            elif type(src) is cls:
                return src.detach()
            else:
                if isinstance(src, MyWrapperLoadTensor):
                    return cls(src._data)
                return cls(src)
        else:
            return src.detach()

    if func is torch.Tensor.module_load:
        return module_load(*args, **kwargs)
    else:
        with torch._C.DisableTorchFunctionSubclass():
            # detach must return instance of same subclass for nn.Parameter()
            if func == torch.Tensor.detach:
                ret = func(*args, **kwargs)
                if not isinstance(ret, cls):
                    return cls(ret)
                return ret
            return func(*args, **kwargs)

class MyLoadTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return load_torch_function_handler(cls, func, types, args, kwargs)

# We use MyLoadTensor2 to test the case of a plain tensor subclass and a wrapper
# tensor subclass where neither inherits from the other.
class MyLoadTensor2(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return load_torch_function_handler(cls, func, types, args, kwargs)

class MyBrokenLoadTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        if func is torch.Tensor.module_load:
            # wrong as this doesn't detach!
            return args[1]
        else:
            with torch._C.DisableTorchFunctionSubclass():
                # detach must return instance of same subclass for nn.Parameter()
                if func == torch.Tensor.detach:
                    return cls(func(*args, **kwargs))
                return func(*args, **kwargs)

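# MyWrapperLoadTensor is a wrapper tensor subclass: the real data lives in
# ``_data``, the wrapper mirrors its metadata via _make_wrapper_subclass, and
# __torch_dispatch__ unwraps inputs, runs the op, and re-wraps the outputs.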
class MyWrapperLoadTensor(MyLoadTensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        t = torch.Tensor._make_wrapper_subclass(
            cls, data.size(),
            dtype=data.dtype, layout=data.layout,
            device=data.device, requires_grad=data.requires_grad,
            strides=data.stride(), storage_offset=data.storage_offset())
        return t

    def __init__(self, data: torch.Tensor):
        self._data = data

    def __repr__(self):
        return f"MyWrapperLoadTensor({self._data.__repr__()})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

        def unwrap(t):
            return t._data if isinstance(t, MyWrapperLoadTensor) else t

        def wrap(t):
            return MyWrapperLoadTensor(t) if isinstance(t, torch.Tensor) else t

        kwargs = {} if kwargs is None else kwargs
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        return tree_map(wrap, out)


class TestLoadStateDictSwap(TestCase):
    @skipIfCrossRef
    @skipIfTorchDynamo("Can't swap with dynamo as dynamo installs weakrefs")
    @swap([True])
    def test_swap_subclass(self):

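        # Build every (module subclass, state_dict subclass) pairing below and
        # check which subclass the swapped-in parameter/buffer ends up with; the
        # expected types are computed inside _test.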
        def _create_model(subclass=None):
            m = torch.nn.Linear(2, 3, bias=False)
            m.register_buffer('buf', torch.randn(2, 3))
            if subclass is not None:
                m.weight = torch.nn.Parameter(subclass(m.weight))
                m.buf = subclass(m.buf)
            return m

        def _test(m_subclass=None, sd_subclass=None):
            m = _create_model(m_subclass)
            sd = _create_model(sd_subclass).state_dict()
            m.load_state_dict(sd)
            self.assertEqual(m.weight, sd['weight'])
            self.assertEqual(m.buf, sd['buf'])
            self.assertTrue(isinstance(m.weight, torch.nn.Parameter))
            self.assertTrue(not isinstance(m.buf, torch.nn.Parameter))

            weight_type, buf_type = (torch.nn.Parameter, torch.Tensor)
            if m_subclass is not None and sd_subclass is not None:
                # handler of subclass takes precedence over superclass
                if issubclass(sd_subclass, m_subclass):
                    weight_type, buf_type = (sd_subclass, sd_subclass)
                else:
                    weight_type, buf_type = (m_subclass, m_subclass)
            elif m_subclass is not None:
                weight_type, buf_type = (m_subclass, m_subclass)
            elif sd_subclass is not None:
                weight_type, buf_type = (sd_subclass, sd_subclass)
            self.assertTrue(type(m.weight) is weight_type)
            self.assertTrue(type(m.buf) is buf_type)

        # (MyLoadTensor, MyWrapperLoadTensor) tests the behavior of (superclass, subclass)
        subclasses = [None, MyLoadTensor, MyLoadTensor2, MyWrapperLoadTensor]
        for m_s, sd_s in product(subclasses, subclasses):
            _test(m_s, sd_s)

        # MyBrokenLoadTensor should error since its module_load doesn't call .detach()
        with self.assertRaisesRegex(RuntimeError, re.escape("Error(s) in loading state_dict for Linear:")):
            _test(None, MyBrokenLoadTensor)


instantiate_parametrized_tests(TestLoadStateDict)
instantiate_parametrized_tests(TestLoadStateDictSwap)

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()