Make load_state_dict() more restrictive (#451)

The load_state_dict() function now raises a KeyError if the state_dict
argument contains extra keys or is missing expected keys.
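
A minimal sketch of the new behavior (the error messages are taken from
the implementation below):

    import torch
    import torch.nn as nn

    l = nn.Linear(5, 5)

    # Extra key: KeyError('unexpected key "extra" in state_dict')
    state_dict = l.state_dict()
    state_dict['extra'] = torch.ones(5)
    try:
        l.load_state_dict(state_dict)
    except KeyError as e:
        print(e)

    # Missing key: KeyError('missing keys in state_dict: ...')
    state_dict = l.state_dict()
    del state_dict['bias']
    try:
        l.load_state_dict(state_dict)
    except KeyError as e:
        print(e)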

Previously, load_state_dict() silently ignored extra and missing keys,
which made it easy to load an invalid state_dict without noticing. This
could happen, for example, if you saved the state_dict of a DataParallel
model but loaded it into the unwrapped model.
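
A hedged sketch of that failure mode (assuming nn.DataParallel, which
stores the wrapped model under the attribute 'module' and therefore
prefixes every key with 'module.'):

    import torch
    import torch.nn as nn

    model = nn.Linear(5, 5)
    wrapped = nn.DataParallel(model)

    # wrapped.state_dict() has keys ['module.weight', 'module.bias'],
    # while the plain model expects ['weight', 'bias'].
    saved = wrapped.state_dict()

    # Previously every key was silently ignored and the model kept its
    # random initialization; now this raises a KeyError immediately.
    model.load_state_dict(saved)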

The state_dict() function now includes only the Tensor data from the
parameters, which reduces checkpoint size by not saving gradients.
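
A short sketch of the new contract (the filename 'checkpoint.pth' is
just for illustration):

    import torch
    import torch.nn as nn

    l = nn.Linear(5, 5)
    state_dict = l.state_dict()

    # Values are the raw tensors (param.data), not Parameter objects,
    # so no autograd state such as gradients is serialized with them.
    assert state_dict['weight'] is l.weight.data
    assert state_dict['bias'] is l.bias.data

    torch.save(state_dict, 'checkpoint.pth')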
diff --git a/test/test_nn.py b/test/test_nn.py
index 8bf016d..a9128fb 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -738,18 +738,20 @@
             param = net
             for component in k.split('.'):
                 param = getattr(param, component)
+                if isinstance(param, Parameter):
+                    param = param.data
             self.assertIs(v, param)
 
         l = nn.Linear(5, 5)
         state_dict = l.state_dict()
         self.assertEqual(len(state_dict), 2)
-        self.assertIs(state_dict['weight'], l.weight)
-        self.assertIs(state_dict['bias'], l.bias)
+        self.assertIs(state_dict['weight'], l.weight.data)
+        self.assertIs(state_dict['bias'], l.bias.data)
 
     def test_load_state_dict(self):
         l = nn.Linear(5, 5)
         block = nn.Container(
-            conv1=nn.Conv2d(3, 3, 3, bias=False),
+            conv1=nn.Conv2d(3, 3, 3, bias=True),
             conv2=nn.Conv2d(3, 3, 3, bias=False),
         )
         net = nn.Container(
@@ -759,22 +761,24 @@
             block=block,
             empty=None,
         )
-        state_dict = {
-            'linear1.weight': Parameter(torch.ones(5, 5)),
-            'block.conv1.bias': Parameter(torch.range(1, 3)),
-            'block.conv2.bias': None,
+        state_dict = net.state_dict()
+        state_dict.update({
+            'linear1.weight': torch.ones(5, 5),
+            'block.conv1.bias': torch.range(1, 3),
             'bn.running_mean': torch.randn(2),
-        }
+        })
         net.load_state_dict(state_dict)
-        self.assertIs(net.linear1.weight, state_dict['linear1.weight'])
-        self.assertIs(net.block.conv1.bias, state_dict['block.conv1.bias'])
-        self.assertIs(net.block.conv2.bias, state_dict['block.conv2.bias'])
-        self.assertIs(net.bn.running_mean, state_dict['bn.running_mean'])
+        self.assertEqual(net.linear1.weight.data, state_dict['linear1.weight'])
+        self.assertEqual(net.block.conv1.bias.data, state_dict['block.conv1.bias'])
+        self.assertEqual(net.bn.running_mean, state_dict['bn.running_mean'])
 
-        state_dict = {
-            'linear1.weight': torch.ones(5, 5)
-        }
-        self.assertRaises(TypeError, lambda: net.load_state_dict(state_dict))
+        state_dict = net.state_dict()
+        state_dict.update({'extra': torch.ones(5)})
+        self.assertRaises(KeyError, lambda: net.load_state_dict(state_dict))
+
+        state_dict = net.state_dict()
+        del state_dict['linear1.weight']
+        self.assertRaises(KeyError, lambda: net.load_state_dict(state_dict))
 
     def test_parameter_assignment(self):
         l = nn.Linear(5, 5)
diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py
index a90dee2..420b01a 100644
--- a/torch/nn/modules/container.py
+++ b/torch/nn/modules/container.py
@@ -93,41 +93,12 @@
             Module.__delattr__(self, name)
 
     def state_dict(self, destination=None, prefix=''):
-        """Returns a dictionary containing a whole state of the model.
-
-        Both parameters and persistent buffers (e.g. running averages) are
-        included. Keys are computed using a natural Python's indexing syntax
-        (e.g. 'subcontainer.module.weight'), excluding ``self``.
-
-        Example:
-            >>> print(model.state_dict().keys())
-            ['conv1.bias', 'conv1.weight']
-        """
         result = super(Container, self).state_dict(destination, prefix)
         for name, module in self._modules.items():
             if module is not None:
                 module.state_dict(result, prefix + name + '.')
         return result
 
-    def load_state_dict(self, state_dict, prefix=''):
-        """Replaces model parameters using values from a given state_dict.
-
-        Copies all state_dict entries, where keys match any of the submodules.
-        For example, if the state_dict has an entry ``'conv44.weight'``, but
-        if the container does not have any submodule named ``'conv44'``, then
-        such entry will be ignored. However, once a module is found, this will
-        load all values from the state dict (including such that weren't
-        registered before loading).
-
-        Arguments:
-            state_dict (dict): A dict containing loaded parameters and
-                persistent buffers.
-        """
-        super(Container, self).load_state_dict(state_dict)
-        for name, module in self._modules.items():
-            if module is not None:
-                module.load_state_dict(state_dict, prefix + name + '.')
-
     def parameters(self, memo=None):
         """Returns an iterator over model parameters (including submodules).
 
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 70cb26f..5e79dd6 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -209,35 +209,42 @@
         included. Keys are corresponding parameter and buffer names.
 
         Example:
-            >>> print(module.state_dict().keys())
+            >>> module.state_dict().keys()
             ['bias', 'weight']
         """
         if destination is None:
             destination = OrderedDict()
-        for name, param in chain(self._buffers.items(), self._parameters.items()):
+        for name, param in self._parameters.items():
             if param is not None:
-                destination[prefix + name] = param
+                destination[prefix + name] = param.data
+        for name, buf in self._buffers.items():
+            if buf is not None:
+                destination[prefix + name] = buf
         return destination
 
-    def load_state_dict(self, state_dict, prefix=''):
-        """Replaces module parameters using values from a given state_dict.
-
-        This will load all values from the state dict (including such that
-        weren't registered before loading).
+    def load_state_dict(self, state_dict):
+        """Copies parameters and buffers from :attr:`state_dict` into
+        this module and its descendants. The keys of :attr:`state_dict` must
+        exactly match the keys returned by this module's :func:`state_dict()`
+        function.
 
         Arguments:
-            state_dict (dict): A dict containing loaded parameters and
+            state_dict (dict): A dict containing parameters and
                 persistent buffers.
         """
-        for name, param in self._parameters.items():
-            new_param = state_dict.get(prefix + name, param)
-            if not isinstance(new_param, Parameter) and new_param is not None:
-                raise TypeError(
-                    "expected torch.autograd.Parameter for key '{}' (got {})"
-                    .format(prefix + name, torch.typename(new_param)))
-            self._parameters[name] = new_param
-        for name, buf in self._buffers.items():
-            self._buffers[name] = state_dict.get(prefix + name, buf)
+        own_state = self.state_dict()
+        for name, param in state_dict.items():
+            if name not in own_state:
+                raise KeyError('unexpected key "{}" in state_dict'
+                               .format(name))
+            if isinstance(param, Parameter):
+                # backwards compatibility for serialized parameters
+                param = param.data
+            own_state[name].copy_(param)
+
+        missing = set(own_state.keys()) - set(state_dict.keys())
+        if len(missing) > 0:
+            raise KeyError('missing keys in state_dict: "{}"'.format(missing))
 
     def parameters(self, memo=None):
         """Returns an iterator over module parameters.