| # Owner(s): ["module: nn"] |
| import re |
| import unittest |
| from copy import deepcopy |
| from itertools import product |
| from tempfile import NamedTemporaryFile |
| |
| import torch |
| import torch.nn as nn |
| from torch.testing._internal.common_nn import NNTestCase |
| from torch.testing._internal.common_utils import ( |
| instantiate_parametrized_tests, |
| IS_WINDOWS, |
| parametrize, |
| run_tests, |
| skipIfCrossRef, |
| skipIfTorchDynamo, |
| swap, |
| TEST_NUMPY, |
| TestCase, |
| ) |
| from torch.utils._pytree import tree_map |
| |
| if TEST_NUMPY: |
| import numpy as np |
| |
| |
| class TestLoadStateDict(NNTestCase): |
| _do_cuda_memory_leak_check = True |
| _do_cuda_non_default_stream = True |
| |
| @unittest.skipIf(not TEST_NUMPY, "numpy not found") |
| @swap([True, False]) |
| def test_load_state_dict_invalid(self): |
| m = torch.nn.Linear(2, 2, bias=False) |
| |
| state_dict = {"weight": np.random.randn(2, 2)} |
| with self.assertRaisesRegex( |
| RuntimeError, |
| "expected torch.Tensor or Tensor-like object from checkpoint but received", |
| ): |
| m.load_state_dict(state_dict) |
| |
| state_dict = {"weight": ((1.0, 1.0), (2.0, 2.0))} |
| with self.assertRaisesRegex( |
| RuntimeError, |
| "expected torch.Tensor or Tensor-like object from checkpoint but received", |
| ): |
| m.load_state_dict(state_dict) |
| |
| @swap([True, False]) |
| def test_load_state_dict_type(self): |
| m = nn.Module() |
| |
| with self.assertRaisesRegex( |
| TypeError, "Expected state_dict to be dict-like, got" |
| ): |
| m.load_state_dict("") |
| with self.assertRaisesRegex( |
| TypeError, "Expected state_dict to be dict-like, got" |
| ): |
| m.load_state_dict(2) |
| |
| @swap([True, False]) |
| @skipIfTorchDynamo("dynamo installs weakrefs on some params") |
| def test_load_state_dict(self): |
        linear = nn.Linear(5, 5)
| block = nn.Module() |
| block.conv1 = nn.Conv2d(3, 3, 3, bias=True) |
| block.conv2 = nn.Conv2d(3, 3, 3, bias=False) |
| net = nn.Module() |
        net.linear1 = linear
        net.linear2 = linear
| net.bn = nn.BatchNorm2d(2) |
| net.block = block |
| net.add_module("empty", None) |
| conv1_bias_dtype = block.conv1.bias.dtype |
| |
| state_dict = net.state_dict() |
| state_dict.update( |
| { |
| "linear1.weight": torch.ones(5, 5), |
| "block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype), |
| "bn.running_mean": torch.randn(2), |
| } |
| ) |
        # Also test that a DDP-style state_dict (keys prefixed with "module.")
        # can be loaded into a local, non-DDP model.
| ddp_state_dict = net.state_dict() |
| ddp_state_dict.update( |
| { |
| "module.linear1.weight": torch.ones(5, 5), |
| "module.block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype), |
| "module.bn.running_mean": torch.randn(2), |
| } |
| ) |
| torch.nn.modules.utils.consume_prefix_in_state_dict_if_present( |
| ddp_state_dict, "module." |
| ) |
| for sd in [state_dict, ddp_state_dict]: |
| incompatible_keys = net.load_state_dict(sd) |
| self.assertEqual(len(incompatible_keys.missing_keys), 0) |
| self.assertEqual(len(incompatible_keys.unexpected_keys), 0) |
| self.assertNotIn("Incompatible", str(incompatible_keys)) |
| self.assertEqual(net.linear1.weight, sd["linear1.weight"]) |
| self.assertEqual(net.block.conv1.bias, sd["block.conv1.bias"]) |
| self.assertEqual(net.bn.running_mean, sd["bn.running_mean"]) |
| |
| state_dict = net.state_dict() |
| state_dict.update({"extra": torch.ones(5)}) |
| self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict)) |
| incompatible_keys = net.load_state_dict(state_dict, strict=False) |
| self.assertEqual(len(incompatible_keys.missing_keys), 0) |
| self.assertEqual(len(incompatible_keys.unexpected_keys), 1) |
| self.assertIn("extra", incompatible_keys.unexpected_keys) |
| self.assertIn("Incompatible", str(incompatible_keys)) |
| |
| state_dict = net.state_dict() |
| state_dict.update({"extra.param": torch.ones(5)}) |
| self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict)) |
| incompatible_keys = net.load_state_dict(state_dict, strict=False) |
| self.assertEqual(len(incompatible_keys.missing_keys), 0) |
| self.assertEqual(len(incompatible_keys.unexpected_keys), 1) |
| self.assertIn("extra.param", incompatible_keys.unexpected_keys) |
| |
| state_dict = net.state_dict() |
| del state_dict["linear1.weight"] |
| self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict)) |
| incompatible_keys = net.load_state_dict(state_dict, strict=False) |
| self.assertEqual(len(incompatible_keys.missing_keys), 1) |
| self.assertEqual(len(incompatible_keys.unexpected_keys), 0) |
| self.assertIn("linear1.weight", incompatible_keys.missing_keys) |
| state_dict.update({"extra.param": torch.ones(5)}) |
| self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict)) |
| incompatible_keys = net.load_state_dict(state_dict, strict=False) |
| self.assertEqual(len(incompatible_keys.missing_keys), 1) |
| self.assertEqual(len(incompatible_keys.unexpected_keys), 1) |
| self.assertIn("linear1.weight", incompatible_keys.missing_keys) |
| self.assertIn("extra.param", incompatible_keys.unexpected_keys) |
| |
| state_dict = net.state_dict() |
| state_dict.update({"bn.running_mean": torch.rand(14, 4)}) # wrong size |
| self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict)) |
| self.assertRaises( |
| RuntimeError, lambda: net.load_state_dict(state_dict, strict=False) |
| ) |
| |
| state_dict = net.state_dict() |
| old_state_dict = deepcopy(state_dict) |
| state_dict = { |
| "linear1.weight": torch.ones(5, 5), |
| "block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype), |
| "bn.running_mean": torch.randn(2), |
| "nonexistent_key": torch.rand(3), |
| } |
| net.load_state_dict(state_dict, strict=False) |
| self.assertEqual(net.linear1.weight, state_dict["linear1.weight"]) |
| self.assertEqual(net.block.conv1.bias, state_dict["block.conv1.bias"]) |
| self.assertEqual(net.bn.running_mean, state_dict["bn.running_mean"]) |
| new_state_dict = net.state_dict() |
| del old_state_dict["linear1.weight"] |
| del old_state_dict["block.conv1.bias"] |
| del old_state_dict["bn.running_mean"] |
        for k, v in old_state_dict.items():
| self.assertTrue(v.equal(new_state_dict[k])) |
| |
| @swap([True, False]) |
| def test_load_state_dict_BC(self): |
        # BatchNormNd
        # The num_batches_tracked buffer was added in version 2. For state dicts
        # from earlier versions, or with no version at all, loading should fall
        # back to a default value of 0.
| bn = nn.BatchNorm2d(3) |
| state_dict = bn.state_dict() |
| del state_dict["num_batches_tracked"] |
| state_dict._metadata[""]["version"] = 1 # version 1 |
| bn.load_state_dict(state_dict) |
| self.assertEqual(bn.num_batches_tracked.dtype, torch.long) |
| self.assertEqual(bn.num_batches_tracked.item(), 0) |
| del state_dict._metadata[""]["version"] # no version |
| bn.load_state_dict(state_dict) |
| self.assertEqual(bn.num_batches_tracked.dtype, torch.long) |
| self.assertEqual(bn.num_batches_tracked.item(), 0) |
| |
| @swap([True, False]) |
| def test_load_state_dict_child(self): |
| base_module = nn.Linear(1, 1) |
| model = base_module |
| for _ in range(3): |
| model = nn.Sequential(*[deepcopy(model) for _ in range(10)]) |
| |
| def hook_fn( |
| module, |
| state_dict, |
| prefix, |
| local_metadata, |
| strict, |
| missing_keys, |
| unexpected_keys, |
| error_msgs, |
| ): |
| module_state_dict = module.state_dict() |
| self.assertEqual(len(module_state_dict.keys()), len(state_dict.keys())) |
| |
| model[0][0]._register_load_state_dict_pre_hook(hook_fn, with_module=True) |
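        # with_module=True passes model[0][0] itself as the first argument to
        # hook_fn; the load below triggers the hook for that child.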
| model.load_state_dict(model.state_dict(), strict=True) |
| |
| @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows") |
| @swap([True, False]) |
| def test_register_state_dict_pre_hook_backward_compat(self): |
| called = False |
| |
| def my_state_dict_pre_hook(*args, **kwargs): |
| nonlocal called |
| called = True |
| |
| m = nn.Linear(1, 1) |
| self.assertTrue(hasattr(m, "_state_dict_pre_hooks")) |
| delattr(m, "_state_dict_pre_hooks") |
| # Save and load, ensure we can still call state_dict |
| # without running into issues. |
| with NamedTemporaryFile() as f: |
            # Note that torch.save / torch.load is not the recommended way
            # to save / load modules.
            torch.save(m, f.name)
            m = torch.load(f.name, weights_only=False)
| |
| # Ensure we can run state_dict without issues |
| _ = m.state_dict() |
| self.assertFalse(called) |
| m.register_state_dict_pre_hook(my_state_dict_pre_hook) |
| _ = m.state_dict() |
| self.assertTrue(called) |
| |
| # fails swapping as LSTM installs weak references on the parameters |
| @swap([False]) |
| @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") |
| def test_load_state_dict_ref_cycle(self): |
| # load_state_dict shouldn't cause a reference cycle involving Tensors |
| import gc |
| |
| m = torch.nn.LSTM(16, 16, bidirectional=True) |
| |
| gc.collect() |
| m.load_state_dict(deepcopy(m).state_dict()) |
| refcycles = gc.collect() |
| |
| self.assertEqual(refcycles, 0) |
| |
| @swap([True, False]) |
| def test_load_state_dict_custom(self): |
| class CustomState(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.param = torch.nn.Parameter(torch.ones(1)) |
| self.sub = torch.nn.Linear(5, 5) |
| |
| def _save_to_state_dict(self, destination, prefix, keep_vars): |
| destination[prefix + "serialized"] = self.param.data + 1 |
| |
| def _load_from_state_dict( |
| self, |
| state_dict, |
| prefix, |
| local_metadata, |
| strict, |
| missing_keys, |
| unexpected_keys, |
| error_msgs, |
| ): |
| # skip some of the error handling |
| self.param.data.copy_(state_dict[prefix + "serialized"] - 1) |
| |
| # use sequential to verify nesting |
| m = nn.Sequential(CustomState()) |
| with torch.no_grad(): |
| m[0].param[0] = 10 |
| m[0].sub.weight[0, 0] = 555 |
| state_dict = m.state_dict() |
| self.assertEqual(state_dict["0.serialized"].item(), 11) |
| self.assertIn("0.sub.weight", state_dict) |
| self.assertNotIn("0.param", state_dict) |
| del m |
| mm = nn.Sequential(CustomState()) |
| self.assertEqual(mm[0].param[0].item(), 1) |
| mm.load_state_dict(state_dict) |
| self.assertEqual(mm[0].param[0].item(), 10) |
| self.assertEqual(mm[0].sub.weight[0, 0].item(), 555) |
| |
| @swap([True, False]) |
| @parametrize("keep_vars", [True, False]) |
| def test_load_state_dict_assign_meta(self, keep_vars): |
| class MyModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.fc1 = nn.Linear(3, 5) |
| self.bn = nn.BatchNorm1d(5) |
| self.x = nn.Parameter(torch.rand(5), requires_grad=False) |
| |
| def forward(self, input): |
| return self.x + self.bn(self.fc1(input)) |
| |
| swap = torch.__future__.get_swap_module_params_on_conversion() |
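        # True when load_state_dict swaps parameters in place via swap_tensors
        # (toggled per-test by the @swap decorator).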
| net = MyModule() |
| state_dict = net.state_dict(keep_vars=keep_vars) |
| for v in state_dict.values(): |
| v.requires_grad_(False) |
| |
| with torch.device("meta"): |
| net_meta = MyModule() |
| |
| net_meta_state_dict_old = net_meta.state_dict(keep_vars=True) |
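        # Snapshot the meta-device tensors so we can check below whether
        # assign=True reused them (swap) or replaced them outright.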
| net_meta.load_state_dict(state_dict, assign=True) |
| |
| # Make sure parameters and persistent buffers were assigned |
| net_meta_state_dict = net_meta.state_dict(keep_vars=True) |
| for key in state_dict.keys(): |
| if key in net_meta._parameters: |
| if keep_vars and not swap: |
| # state_dict[key] is an nn.Parameter |
| self.assertTrue(state_dict[key] is net_meta_state_dict[key]) |
| else: |
| if swap: |
| self.assertTrue( |
| net_meta_state_dict[key] is net_meta_state_dict_old[key] |
| ) |
| else: |
| # state_dict[key] is not an nn.Parameter so it will be detached when wrapping with a Parameter |
| self.assertTrue( |
| net_meta_state_dict[key] is not net_meta_state_dict_old[key] |
| ) |
| self.assertEqual( |
| net_meta_state_dict_old[key].requires_grad, |
| net_meta_state_dict[key].requires_grad, |
| ) |
| self.assertEqual(state_dict[key], net_meta_state_dict[key]) |
| elif ( |
| key in net_meta._buffers |
| and key not in net_meta._non_persistent_buffers_set |
| ): |
| self.assertTrue(state_dict[key] is net_meta_state_dict[key]) |
| self.assertEqual(state_dict[key], net_meta_state_dict[key]) |
| |
| # Make sure that ordering of parameters and buffers is preserved |
| net_named_parameters = net.named_parameters() |
| net_named_buffers = net.named_buffers() |
| net_meta_named_parameters = net_meta.named_parameters() |
| net_meta_named_buffers = net_meta.named_buffers() |
| |
| for (n1, _), (n2, _) in zip(net_named_parameters, net_meta_named_parameters): |
| self.assertEqual(n1, n2) |
| |
| for (n1, _), (n2, _) in zip(net_named_buffers, net_meta_named_buffers): |
| self.assertEqual(n1, n2) |
| |
| # Make sure outputs are the same |
| t = torch.randn(4, 3) |
| out_net = net(t) |
| out_net_meta = net_meta(t.clone()) |
| |
| self.assertEqual(out_net, out_net_meta) |
| |
| @swap([True, False]) |
| def test_load_state_dict_assign_with_optimizer(self): |
| class MyModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.fc1 = nn.Linear(3, 5) |
| self.bn = nn.BatchNorm1d(5) |
| |
| def forward(self, input): |
| return self.bn(self.fc1(input)) |
| |
| net = MyModule() |
| opt = torch.optim.Adam(net.parameters(), lr=1000) |
| x = torch.randn(4, 3) |
| num_iters = 3 |
| |
| for i in range(num_iters): |
| opt.zero_grad() |
| out = net(x) |
| out.sum().backward() |
| opt.step() |
| |
| opt_state_dict = deepcopy(opt.state_dict()) |
| net_state_dict = deepcopy(net.state_dict()) |
| |
| with torch.device("meta"): |
| net_meta = MyModule() |
| |
| net_meta.load_state_dict(net_state_dict, assign=True) |
        # The optimizer must be created only after load_state_dict(assign=True),
        # since assign may replace the parameter objects the optimizer would
        # otherwise hold references to.
| opt2 = torch.optim.Adam(net_meta.parameters(), lr=1000) |
| opt2.load_state_dict(opt_state_dict) |
| |
| y = x.clone() |
| for i in range(num_iters): |
| opt.zero_grad() |
| out = net(x) |
| out.sum().backward() |
| opt.step() |
| |
| opt2.zero_grad() |
| out2 = net_meta(y) |
| out2.sum().backward() |
| opt2.step() |
| |
| self.assertEqual(opt.state_dict(), opt2.state_dict()) |
| self.assertEqual(net.state_dict(), net_meta.state_dict()) |
| |
| @swap([True, False]) |
| def test_load_state_dict_assign_shape_stride(self): |
        # With assign=True the assigned tensor may differ from the initial
        # tensor in any property (e.g. stride) except its shape.
| class MyModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.fc1 = nn.Linear(3, 5) |
| self.bn = nn.BatchNorm1d(5) |
| |
| def forward(self, input): |
| return self.bn(self.fc1(input)) |
| |
| net = MyModule() |
| state_dict = net.state_dict() |
| # loading should be ok if stride is different |
| state_dict["fc1.weight"] = torch.randn(3, 5).transpose(0, 1) |
| net2 = MyModule() |
| net2.load_state_dict(state_dict, strict=False, assign=True) |
| |
| state_dict["fc1.weight"] = torch.randn(2, 4) |
| with self.assertRaisesRegex( |
| RuntimeError, "size mismatch for fc1.weight: copying a param with shape" |
| ): |
| net2.load_state_dict(state_dict, strict=False, assign=True) |
| |
| @swap([True, False]) |
| def test_load_state_dict_warn_assign(self): |
| with torch.device("meta"): |
| m = torch.nn.Linear(3, 5) |
| state_dict = m.state_dict() |
| state_dict["weight"] = torch.empty_like(state_dict["weight"], device="cpu") |
| with self.assertWarnsRegex( |
| UserWarning, |
| "for weight: copying from a non-meta parameter in the checkpoint to a meta", |
| ): |
| m.load_state_dict(state_dict) |
| |
| @swap([True, False]) |
| def test_load_state_dict_with_unexpected_key(self): |
| class MyModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.fc1 = torch.nn.Linear(5, 10) |
| |
| m = MyModule() |
| |
| # Unexpected key & strict = True |
| with self.assertRaisesRegex(RuntimeError, "Unexpected key"): |
| state_dict = m.state_dict() |
| state_dict["fc1.bad_suffix"] = torch.randn(5, 10) |
| m.load_state_dict(state_dict) |
| |
| # Unexpected key & strict = False |
| state_dict = m.load_state_dict(state_dict, strict=False) |
| self.assertIn("fc1.bad_suffix", state_dict.unexpected_keys) |
| |
| # Unexpected key whose prefix matches a valid key & strict = True |
| with self.assertRaisesRegex(RuntimeError, "Unexpected key"): |
| state_dict = m.state_dict() |
| state_dict["fc1.weight.bad_suffix"] = torch.randn(5, 10) |
| m.load_state_dict(state_dict) |
| |
| # Unexpected key whose prefix matches a valid key & strict = False |
| state_dict = m.load_state_dict(state_dict, strict=False) |
| self.assertIn("fc1.weight.bad_suffix", state_dict.unexpected_keys) |
| |
| |
| def load_torch_function_handler(cls, func, types, args=(), kwargs=None): |
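    # Shared __torch_function__ implementation for the load-test subclasses below:
    # it implements Tensor.module_load (deciding which tensor ends up in the
    # module) and makes Tensor.detach preserve the subclass.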
| kwargs = {} if kwargs is None else kwargs |
| |
| def module_load(dest, src, assign=False): |
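        # dest is the module's existing parameter/buffer; src is the value
        # coming from the state_dict being loaded.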
| if isinstance(dest, cls): |
| if assign: |
| return src.detach() |
| else: |
| if type(src) is torch.Tensor: |
| return cls(src) |
| elif type(src) is cls: |
| return src.detach() |
| else: |
| if isinstance(src, MyWrapperLoadTensor): |
| return cls(src._data) |
| return cls(src) |
| else: |
| assert isinstance( |
| src, cls |
| ), f"Expected isinstance(src, {cls}) but got {type(src)}" |
| assert ( |
| type(dest) == torch.Tensor |
| or type(dest) == torch.nn.Parameter |
| or issubclass(cls, type(dest)) |
| ) |
| if assign: |
| return src.detach() |
| else: |
| if isinstance(src, MyWrapperLoadTensor): |
| if type(dest) not in {torch.Tensor, torch.nn.Parameter}: |
| return type(dest)(src._data) |
| else: |
| return src._data.detach() |
| else: |
| return torch.Tensor(src) |
| |
| if func is torch.Tensor.module_load: |
| return module_load(*args, **kwargs) |
| else: |
| with torch._C.DisableTorchFunctionSubclass(): |
| # detach must return instance of same subclass for nn.Parameter() |
| if func == torch.Tensor.detach: |
| ret = func(*args, **kwargs) |
| if not isinstance(ret, cls): |
| return cls(ret) |
| return ret |
| return func(*args, **kwargs) |
| |
| |
| class MyLoadTensor(torch.Tensor): |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| return load_torch_function_handler(cls, func, types, args, kwargs) |
| |
| |
# We use MyLoadTensor2 to test the (tensor subclass, wrapper tensor subclass)
# combination where neither class is a subclass of the other
# (MyWrapperLoadTensor subclasses MyLoadTensor, but not MyLoadTensor2).
| class MyLoadTensor2(torch.Tensor): |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| return load_torch_function_handler(cls, func, types, args, kwargs) |
| |
| |
| class MyBrokenLoadTensor(torch.Tensor): |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| kwargs = {} if kwargs is None else kwargs |
| |
| if func is torch.Tensor.module_load: |
| # wrong as this doesn't detach! |
| return args[1] |
| else: |
| with torch._C.DisableTorchFunctionSubclass(): |
| # detach must return instance of same subclass for nn.Parameter() |
| if func == torch.Tensor.detach: |
| return cls(func(*args, **kwargs)) |
| return func(*args, **kwargs) |
| |
| |
| class MyWrapperLoadTensor(MyLoadTensor): |
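    # Wrapper tensor subclass: the real data lives in self._data, and ops are
    # unwrapped / re-wrapped in __torch_dispatch__. module_load handling is
    # inherited from MyLoadTensor's __torch_function__.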
| @staticmethod |
| def __new__(cls, data: torch.Tensor): |
| t = torch.Tensor._make_wrapper_subclass( |
| cls, |
| data.size(), |
| dtype=data.dtype, |
| layout=data.layout, |
| device=data.device, |
| requires_grad=data.requires_grad, |
| strides=data.stride(), |
| storage_offset=data.storage_offset(), |
| ) |
| return t |
| |
| def __init__(self, data: torch.Tensor): |
| self._data = data |
| |
| def __repr__(self): |
| return f"MyWrapperLoadTensor({self._data.__repr__()})" |
| |
| @classmethod |
| def __torch_dispatch__(cls, func, types, args=(), kwargs=None): |
| def unwrap(t): |
| return t._data if isinstance(t, MyWrapperLoadTensor) else t |
| |
| def wrap(t): |
| return MyWrapperLoadTensor(t) if isinstance(t, torch.Tensor) else t |
| |
| kwargs = {} if kwargs is None else kwargs |
| out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) |
| return tree_map(wrap, out) |
| |
| |
| class TestLoadStateDictSwap(TestCase): |
| @skipIfCrossRef |
| @skipIfTorchDynamo("Can't swap with dynamo as dynamo installs weakrefs") |
| @swap([True]) |
| @parametrize("assign", [True, False]) |
| def test_swap_subclass(self, assign): |
| def _create_model(subclass=None): |
| m = torch.nn.Linear(2, 3, bias=False) |
| m.register_buffer("buf", torch.randn(2, 3)) |
| if subclass is not None: |
| m.weight = torch.nn.Parameter(subclass(m.weight)) |
| m.buf = subclass(m.buf) |
| return m |
| |
| def _test(m_subclass=None, sd_subclass=None): |
| m = _create_model(m_subclass) |
| sd = _create_model(sd_subclass).state_dict() |
| m.load_state_dict(sd, assign=assign) |
| self.assertEqual(m.weight, sd["weight"]) |
| self.assertEqual(m.buf, sd["buf"]) |
| self.assertTrue(isinstance(m.weight, torch.nn.Parameter)) |
| self.assertTrue(not isinstance(m.buf, torch.nn.Parameter)) |
| |
| weight_type, buf_type = (torch.nn.Parameter, torch.Tensor) |
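            # assign=True keeps the state_dict tensor's subclass; otherwise
            # swapping preserves the module's original subclass.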
| if assign: |
| if sd_subclass is not None: |
| weight_type, buf_type = (sd_subclass, sd_subclass) |
| else: |
| if m_subclass is not None: |
| weight_type, buf_type = (m_subclass, m_subclass) |
| |
| self.assertTrue(type(m.weight) is weight_type) |
| self.assertTrue(type(m.buf) is buf_type) |
| |
| # (MyLoadTensor, MyWrapperLoadTensor) tests the behavior of (superclass, subclass) |
| subclasses = [None, MyLoadTensor, MyLoadTensor2, MyWrapperLoadTensor] |
| for m_s, sd_s in product(subclasses, subclasses): |
| _test(m_s, sd_s) |
| |
| # MyBrokenLoadTensor should error since its module_load doesn't call .detach() |
| with self.assertRaisesRegex( |
| RuntimeError, re.escape("Error(s) in loading state_dict for Linear:") |
| ): |
| _test(None, MyBrokenLoadTensor) |
| |
| |
| instantiate_parametrized_tests(TestLoadStateDict) |
| instantiate_parametrized_tests(TestLoadStateDictSwap) |
| |
| if __name__ == "__main__": |
| TestCase._default_dtype_check_enabled = True |
| run_tests() |