Allow arbitrary objects in state_dicts (#62976)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/62094

Introduces functionality for adding arbitrary objects to module state_dicts. To take advantage of this, the following functions can be defined on a module:
* `get_extra_state(self) -> Any` - Returns any picklable object (not necessarily a dict) holding the extra state this module wants to save
* `set_extra_state(self, state)` - Restores the given extra state, as previously returned by `get_extra_state`, into the module (see the sketch below)
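
For illustration, here is a minimal sketch of a module that opts in; the `Counter` module and its `num_batches_seen` attribute are hypothetical, but the two hooks follow the API introduced by this PR:

```python
import torch


class Counter(torch.nn.Module):
    """Toy module whose non-tensor state rides along in its state_dict."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.num_batches_seen = 0  # plain Python int, not a buffer

    def forward(self, x):
        self.num_batches_seen += 1
        return self.linear(x)

    def get_extra_state(self):
        # Anything picklable works; a dict keeps the format extensible.
        return {'num_batches_seen': self.num_batches_seen}

    def set_extra_state(self, state):
        self.num_batches_seen = state['num_batches_seen']
```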

Internally, the extra state returned by each such module is stored in the state_dict under a key built from that module's prefix plus the suffix `_extra_state`.
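
Continuing the hypothetical `Counter` sketch above, the extra state shows up as an ordinary entry in the state_dict and is handed back to `set_extra_state` on load:

```python
m = Counter()
m(torch.randn(2, 4))  # num_batches_seen becomes 1

sd = m.state_dict()
# For the root module the prefix is empty, so the key is just '_extra_state'.
assert '_extra_state' in sd
assert sd['_extra_state'] == {'num_batches_seen': 1}


# For a submodule, the key carries the usual prefix, e.g. 'sub._extra_state'.
class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Counter()


assert 'sub._extra_state' in Wrapper().state_dict()

# load_state_dict() routes the entry back through set_extra_state().
m2 = Counter()
m2.load_state_dict(sd)
assert m2.num_batches_seen == 1
```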

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62976

Reviewed By: heitorschueroff

Differential Revision: D30518657

Pulled By: jbschlosser

fbshipit-source-id: 5fb35ab8e3d36f35e3e96dcd4498f8c917d1f386
diff --git a/test/test_nn.py b/test/test_nn.py
index 43e105a..d577493 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -5465,6 +5465,92 @@
         self.assertEqual(mm[0].param[0].item(), 10)
         self.assertEqual(mm[0].sub.weight[0, 0].item(), 555)
 
+    def test_extra_state(self):
+
+        class SubModule(torch.nn.Module):
+            def __init__(self, foo):
+                super().__init__()
+                self.foo = foo
+
+            def get_extra_state(self):
+                return {
+                    'foo': self.foo
+                }
+
+            def set_extra_state(self, state):
+                self.foo = state['foo']
+
+        class MyModule(torch.nn.Module):
+            def __init__(self, foo, bar):
+                super().__init__()
+                self.sub = SubModule(foo)
+                self.bar = bar
+
+            def get_extra_state(self):
+                return {
+                    'bar': self.bar
+                }
+
+            def set_extra_state(self, state):
+                self.bar = state['bar']
+
+        # Ensure state_dict contains the extra state by loading it into another module.
+        m = MyModule(3, 'something')
+        m2 = MyModule(5, 'something else')
+        m2.load_state_dict(m.state_dict())
+        self.assertEqual(m.state_dict(), m2.state_dict())
+        self.assertEqual(m2.bar, m.bar)
+        self.assertEqual(m2.sub.foo, m.sub.foo)
+
+    def test_extra_state_non_dict(self):
+
+        class MyModule(torch.nn.Module):
+            def __init__(self, foo):
+                super().__init__()
+                self.foo = foo
+
+            def get_extra_state(self):
+                return self.foo
+
+            def set_extra_state(self, state):
+                self.foo = state
+
+        # Test various types of extra state.
+        for state in ('something', 5, MyModule(3)):
+            m = MyModule(state)
+            m2 = MyModule('something else')
+            m2.load_state_dict(m.state_dict())
+            self.assertEqual(m.state_dict(), m2.state_dict())
+            self.assertEqual(m.foo, m2.foo)
+
+    def test_extra_state_missing_set_extra_state(self):
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def get_extra_state(self):
+                return {
+                    'foo': 5
+                }
+
+        m = MyModule()
+        with self.assertRaisesRegex(RuntimeError, 'Unexpected key'):
+            m.load_state_dict(m.state_dict())
+
+    def test_extra_state_missing_get_extra_state(self):
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def set_extra_state(self, state):
+                pass
+
+        m = MyModule()
+        with self.assertRaisesRegex(RuntimeError, 'Missing key'):
+            m.load_state_dict(m.state_dict())
+
     def test_parameter_assignment(self):
         l = nn.Linear(5, 5)
 
diff --git a/torch/jit/_script.py b/torch/jit/_script.py
index 0c3e5ef..3d173ae 100644
--- a/torch/jit/_script.py
+++ b/torch/jit/_script.py
@@ -912,6 +912,8 @@
         "_tracing_name",
         "eval",
         "train",
+        "get_extra_state",
+        "set_extra_state"
     }
 
     def _make_fail(name):
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 2376422..28b220e 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -46,6 +46,8 @@
 _global_forward_pre_hooks: Dict[int, Callable] = OrderedDict()
 _global_forward_hooks: Dict[int, Callable] = OrderedDict()
 
+_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
+
 
 def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
     r"""Registers a forward pre-hook common to all modules.
@@ -528,6 +530,41 @@
 
         return buffer
 
+    def get_extra_state(self) -> Any:
+        """
+        Returns any extra state to include in the module's state_dict.
+        Implement this and a corresponding :func:`set_extra_state` for your module
+        if you need to store extra state. This function is called when building the
+        module's `state_dict()`.
+
+        Note that extra state should be pickleable to ensure working serialization
+        of the state_dict. We only provide backwards compatibility guarantees
+        for serializing Tensors; other objects may break backwards compatibility if
+        their serialized pickled form changes.
+
+        Returns:
+            object: Any extra state to store in the module's state_dict
+        """
+        raise RuntimeError(
+            "Reached a code path in Module.get_extra_state() that should never be called. "
+            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.md "
+            "to report this bug.")
+
+    def set_extra_state(self, state: Any):
+        """
+        This function is called from :func:`load_state_dict` to handle any extra state
+        found within the `state_dict`. Implement this function and a corresponding
+        :func:`get_extra_state` for your module if you need to store extra state within its
+        `state_dict`.
+
+        Args:
+            state (Any): Extra state from the `state_dict`
+        """
+        raise RuntimeError(
+            "Reached a code path in Module.set_extra_state() that should never be called. "
+            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.md "
+            "to report this bug.")
+
     def _apply(self, fn):
         for module in self.children():
             module._apply(fn)
@@ -1228,6 +1265,9 @@
         for name, buf in self._buffers.items():
             if buf is not None and name not in self._non_persistent_buffers_set:
                 destination[prefix + name] = buf if keep_vars else buf.detach()
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
+            destination[extra_state_key] = self.get_extra_state()
 
     # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
     # back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
@@ -1365,9 +1405,18 @@
             elif strict:
                 missing_keys.append(key)
 
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
+            if extra_state_key in state_dict:
+                self.set_extra_state(state_dict[extra_state_key])
+            elif strict:
+                missing_keys.append(extra_state_key)
+        elif strict and (extra_state_key in state_dict):
+            unexpected_keys.append(extra_state_key)
+
         if strict:
             for key in state_dict.keys():
-                if key.startswith(prefix):
+                if key.startswith(prefix) and key != extra_state_key:
                     input_name = key[len(prefix):]
                     input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                     if input_name not in self._modules and input_name not in local_state: