Make load_state_dict() more restrictive (#451)
The load_state_dict() function now raises a KeyError if the state_dict
argument has extra or missing keys.
Previously, load_state_dict() ignored extra and missing keys, which made
it easy to load an invalid state_dict without noticing. This could
happen, for example, if you saved the state_dict of a DataParallel
module but loaded it into a single, unwrapped model.
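For example (a minimal sketch; MyModel stands in for any nn.Module
subclass), DataParallel registers the wrapped model under the name
'module', so every key in its state_dict carries a 'module.' prefix, and
loading such a checkpoint into the bare model now fails loudly instead
of silently doing nothing:

    model = MyModel()
    wrapped = nn.DataParallel(model)
    torch.save(wrapped.state_dict(), 'checkpoint.pth')  # keys: 'module.weight', ...

    fresh = MyModel()
    # Raises KeyError: unexpected key "module.weight" in state_dict
    fresh.load_state_dict(torch.load('checkpoint.pth'))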
The state_dict() function now only includes the Tensor data from the
parameters, which reduces checkpoint size by not saving gradients.
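A minimal sketch of the new return values (mirroring the updated tests
below):

    l = nn.Linear(5, 5)
    state_dict = l.state_dict()
    state_dict['weight'] is l.weight       # False: no longer the Parameter itself
    state_dict['weight'] is l.weight.data  # True: just the underlying Tensor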
diff --git a/test/test_nn.py b/test/test_nn.py
index 8bf016d..a9128fb 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -738,18 +738,20 @@
param = net
for component in k.split('.'):
param = getattr(param, component)
+ if isinstance(param, Parameter):
+ param = param.data
self.assertIs(v, param)
l = nn.Linear(5, 5)
state_dict = l.state_dict()
self.assertEqual(len(state_dict), 2)
- self.assertIs(state_dict['weight'], l.weight)
- self.assertIs(state_dict['bias'], l.bias)
+ self.assertIs(state_dict['weight'], l.weight.data)
+ self.assertIs(state_dict['bias'], l.bias.data)
def test_load_state_dict(self):
l = nn.Linear(5, 5)
block = nn.Container(
- conv1=nn.Conv2d(3, 3, 3, bias=False),
+ conv1=nn.Conv2d(3, 3, 3, bias=True),
conv2=nn.Conv2d(3, 3, 3, bias=False),
)
net = nn.Container(
@@ -759,22 +761,24 @@
block=block,
empty=None,
)
- state_dict = {
- 'linear1.weight': Parameter(torch.ones(5, 5)),
- 'block.conv1.bias': Parameter(torch.range(1, 3)),
- 'block.conv2.bias': None,
+ state_dict = net.state_dict()
+ state_dict.update({
+ 'linear1.weight': torch.ones(5, 5),
+ 'block.conv1.bias': torch.range(1, 3),
'bn.running_mean': torch.randn(2),
- }
+ })
net.load_state_dict(state_dict)
- self.assertIs(net.linear1.weight, state_dict['linear1.weight'])
- self.assertIs(net.block.conv1.bias, state_dict['block.conv1.bias'])
- self.assertIs(net.block.conv2.bias, state_dict['block.conv2.bias'])
- self.assertIs(net.bn.running_mean, state_dict['bn.running_mean'])
+ self.assertEqual(net.linear1.weight.data, state_dict['linear1.weight'])
+ self.assertEqual(net.block.conv1.bias.data, state_dict['block.conv1.bias'])
+ self.assertEqual(net.bn.running_mean, state_dict['bn.running_mean'])
- state_dict = {
- 'linear1.weight': torch.ones(5, 5)
- }
- self.assertRaises(TypeError, lambda: net.load_state_dict(state_dict))
+ state_dict = net.state_dict()
+ state_dict.update({'extra': torch.ones(5)})
+ self.assertRaises(KeyError, lambda: net.load_state_dict(state_dict))
+
+ state_dict = net.state_dict()
+ del state_dict['linear1.weight']
+ self.assertRaises(KeyError, lambda: net.load_state_dict(state_dict))
def test_parameter_assignment(self):
l = nn.Linear(5, 5)
diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py
index a90dee2..420b01a 100644
--- a/torch/nn/modules/container.py
+++ b/torch/nn/modules/container.py
@@ -93,41 +93,12 @@
Module.__delattr__(self, name)
def state_dict(self, destination=None, prefix=''):
- """Returns a dictionary containing a whole state of the model.
-
- Both parameters and persistent buffers (e.g. running averages) are
- included. Keys are computed using a natural Python's indexing syntax
- (e.g. 'subcontainer.module.weight'), excluding ``self``.
-
- Example:
- >>> print(model.state_dict().keys())
- ['conv1.bias', 'conv1.weight']
- """
result = super(Container, self).state_dict(destination, prefix)
for name, module in self._modules.items():
if module is not None:
module.state_dict(result, prefix + name + '.')
return result
- def load_state_dict(self, state_dict, prefix=''):
- """Replaces model parameters using values from a given state_dict.
-
- Copies all state_dict entries, where keys match any of the submodules.
- For example, if the state_dict has an entry ``'conv44.weight'``, but
- if the container does not have any submodule named ``'conv44'``, then
- such entry will be ignored. However, once a module is found, this will
- load all values from the state dict (including such that weren't
- registered before loading).
-
- Arguments:
- state_dict (dict): A dict containing loaded parameters and
- persistent buffers.
- """
- super(Container, self).load_state_dict(state_dict)
- for name, module in self._modules.items():
- if module is not None:
- module.load_state_dict(state_dict, prefix + name + '.')
-
def parameters(self, memo=None):
"""Returns an iterator over model parameters (including submodules).
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 70cb26f..5e79dd6 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -209,35 +209,42 @@
included. Keys are corresponding parameter and buffer names.
Example:
- >>> print(module.state_dict().keys())
+ >>> module.state_dict().keys()
['bias', 'weight']
"""
if destination is None:
destination = OrderedDict()
- for name, param in chain(self._buffers.items(), self._parameters.items()):
+ for name, param in self._parameters.items():
if param is not None:
- destination[prefix + name] = param
+ destination[prefix + name] = param.data
+ for name, buf in self._buffers.items():
+ if buf is not None:
+ destination[prefix + name] = buf
return destination
- def load_state_dict(self, state_dict, prefix=''):
- """Replaces module parameters using values from a given state_dict.
-
- This will load all values from the state dict (including such that
- weren't registered before loading).
+ def load_state_dict(self, state_dict):
+ """Copies parameters and buffers from :attr:`state_dict` into
+ this module and its descendants. The keys of :attr:`state_dict` must
+ exactly match the keys returned by this module's :func:`state_dict()`
+ function.
Arguments:
- state_dict (dict): A dict containing loaded parameters and
+ state_dict (dict): A dict containing parameters and
persistent buffers.
"""
- for name, param in self._parameters.items():
- new_param = state_dict.get(prefix + name, param)
- if not isinstance(new_param, Parameter) and new_param is not None:
- raise TypeError(
- "expected torch.autograd.Parameter for key '{}' (got {})"
- .format(prefix + name, torch.typename(new_param)))
- self._parameters[name] = new_param
- for name, buf in self._buffers.items():
- self._buffers[name] = state_dict.get(prefix + name, buf)
+ own_state = self.state_dict()
+ for name, param in state_dict.items():
+ if name not in own_state:
+ raise KeyError('unexpected key "{}" in state_dict'
+ .format(name))
+ if isinstance(param, Parameter):
+ # backwards compatibility for serialized parameters
+ param = param.data
+ own_state[name].copy_(param)
+
+ missing = set(own_state.keys()) - set(state_dict.keys())
+ if len(missing) > 0:
+ raise KeyError('missing keys in state_dict: "{}"'.format(missing))
def parameters(self, memo=None):
"""Returns an iterator over module parameters.