Allow arbitrary objects in state_dicts (#62976)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/62094
Introduces functionality for adding arbitrary objects to module state_dicts. To take advantage of this, the following functions can be defined on a module:
* `get_extra_state(self) -> Any` - Returns any extra state (typically a dict, but any picklable object works) this module wants to save
* `set_extra_state(self, state)` - Subsumes the given state within the module
Under the hood, each module that defines extra state has the object returned by `get_extra_state()` stored in the state_dict under that module's key prefix followed by `_extra_state`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62976
Reviewed By: heitorschueroff
Differential Revision: D30518657
Pulled By: jbschlosser
fbshipit-source-id: 5fb35ab8e3d36f35e3e96dcd4498f8c917d1f386
diff --git a/test/test_nn.py b/test/test_nn.py
index 43e105a..d577493 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -5465,6 +5465,92 @@
self.assertEqual(mm[0].param[0].item(), 10)
self.assertEqual(mm[0].sub.weight[0, 0].item(), 555)
+ def test_extra_state(self):
+ # Extra state defined by a module and by its submodule should both survive a state_dict round trip.
+ class SubModule(torch.nn.Module):
+ def __init__(self, foo):
+ super().__init__()
+ self.foo = foo
+
+ def get_extra_state(self):
+ return {
+ 'foo': self.foo
+ }
+
+ def set_extra_state(self, state):
+ self.foo = state['foo']
+
+ class MyModule(torch.nn.Module):
+ def __init__(self, foo, bar):
+ super().__init__()
+ self.sub = SubModule(foo)
+ self.bar = bar
+
+ def get_extra_state(self):
+ return {
+ 'bar': self.bar
+ }
+
+ def set_extra_state(self, state):
+ self.bar = state['bar']
+
+ # Ensure state_dict contains the extra state by loading it into another module.
+ m = MyModule(3, 'something')
+ m2 = MyModule(5, 'something else')
+ m2.load_state_dict(m.state_dict())
+ self.assertEqual(m.state_dict(), m2.state_dict())
+ self.assertEqual(m2.bar, m.bar)
+ self.assertEqual(m2.sub.foo, m.sub.foo)
+
+ def test_extra_state_non_dict(self):
+ # get_extra_state() may return any object, not only a dict; each value must round-trip through load_state_dict().
+ class MyModule(torch.nn.Module):
+ def __init__(self, foo):
+ super().__init__()
+ self.foo = foo
+
+ def get_extra_state(self):
+ return self.foo
+
+ def set_extra_state(self, state):
+ self.foo = state
+
+ # Test various types of extra state.
+ for state in ('something', 5, MyModule(3)):
+ m = MyModule(state)
+ m2 = MyModule('something else')
+ m2.load_state_dict(m.state_dict())
+ self.assertEqual(m.state_dict(), m2.state_dict())
+ self.assertEqual(m.foo, m2.foo)
+
+ def test_extra_state_missing_set_extra_state(self):
+ # A module overriding only get_extra_state() emits an extra-state key it cannot load, so loading must report it as unexpected.
+ class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def get_extra_state(self):
+ return {
+ 'foo': 5
+ }
+
+ m = MyModule()
+ with self.assertRaisesRegex(RuntimeError, 'Unexpected key'):
+ m.load_state_dict(m.state_dict())
+
+ def test_extra_state_missing_get_extra_state(self):
+ # A module overriding only set_extra_state() produces no extra-state key, so loading its own state_dict must report it as missing.
+ class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ # NOTE(review): the stub below omits the `state` parameter of Module.set_extra_state; it is not expected to be invoked by this test.
+ def set_extra_state(self):
+ pass
+
+ m = MyModule()
+ with self.assertRaisesRegex(RuntimeError, 'Missing key'):
+ m.load_state_dict(m.state_dict())
+
def test_parameter_assignment(self):
l = nn.Linear(5, 5)
diff --git a/torch/jit/_script.py b/torch/jit/_script.py
index 0c3e5ef..3d173ae 100644
--- a/torch/jit/_script.py
+++ b/torch/jit/_script.py
@@ -912,6 +912,8 @@
"_tracing_name",
"eval",
"train",
+ "get_extra_state",
+ "set_extra_state"
}
def _make_fail(name):
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 2376422..28b220e 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -46,6 +46,8 @@
_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict()
_global_forward_hooks: Dict[int, Callable] = OrderedDict()
+_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
+
def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
r"""Registers a forward pre-hook common to all modules.
@@ -528,6 +530,41 @@
return buffer
+ def get_extra_state(self) -> Any:
+ """
+ Returns any extra state to include in the module's state_dict.
+ Implement this and a corresponding :func:`set_extra_state` for your module
+ if you need to store extra state. This function is called when building the
+ module's `state_dict()`.
+
+ Note that extra state should be picklable to ensure working serialization
+ of the state_dict. We only provide backwards compatibility guarantees
+ for serializing Tensors; other objects may break backwards compatibility if
+ their serialized pickled form changes.
+
+ Returns:
+ object: Any extra state to store in the module's state_dict
+ """
+ raise RuntimeError(
+ "Reached a code path in Module.get_extra_state() that should never be called. "
+ "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.md "
+ "to report this bug.")
+
+ def set_extra_state(self, state: Any):
+ """
+ This function is called from :func:`load_state_dict` to handle any extra state
+ found within the `state_dict`. Implement this function and a corresponding
+ :func:`get_extra_state` for your module if you need to store extra state within its
+ `state_dict`.
+
+ Args:
+ state (object): Extra state from the `state_dict`
+ """
+ raise RuntimeError(
+ "Reached a code path in Module.set_extra_state() that should never be called. "
+ "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.md "
+ "to report this bug.")
+
def _apply(self, fn):
for module in self.children():
module._apply(fn)
@@ -1228,6 +1265,9 @@
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach()
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+ if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
+ destination[extra_state_key] = self.get_extra_state()
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
# back that same object. But if they pass nothing, an `OrederedDict` is created and returned.
@@ -1365,9 +1405,18 @@
elif strict:
missing_keys.append(key)
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+ if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
+ if extra_state_key in state_dict:
+ self.set_extra_state(state_dict[extra_state_key])
+ elif strict:
+ missing_keys.append(extra_state_key)
+ elif strict and (extra_state_key in state_dict):
+ unexpected_keys.append(extra_state_key)
+
if strict:
for key in state_dict.keys():
- if key.startswith(prefix):
+ if key.startswith(prefix) and key != extra_state_key:
input_name = key[len(prefix):]
input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
if input_name not in self._modules and input_name not in local_state: