Revert "Load state dict post hook"
This reverts commit 56bed0dcfe7ca9047e5c95a6f3d7fcb0ec403b0c.
Reverted https://github.com/pytorch/pytorch/pull/76823 on behalf of https://github.com/rohan-varma
diff --git a/test/test_nn.py b/test/test_nn.py
index 9dfb307..91c5b70 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -21557,62 +21557,6 @@
m.load_state_dict(state_dict)
self.assertEqual(2, hook_called)
- def test_load_state_dict_post_hook(self):
- hook_called = 0
-
- class MyModule(nn.Module):
- def __init__(self):
- super(MyModule, self).__init__()
- self.foo = torch.nn.Parameter(torch.rand(10))
-
- def my_post_load_hook(self, module, incompatible_keys):
- assert module is self
- nonlocal hook_called
- incompatible_keys.missing_keys.append("foo")
- incompatible_keys.unexpected_keys.append("bar")
- hook_called += 1
-
- nested = MyModule()
- wrapped = nn.ModuleList([nested])
- handle = nested.register_load_state_dict_post_hook(
- nested.my_post_load_hook,
- )
- # Hook must be called even if it is wrapped
- ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False)
- self.assertEqual(hook_called, 1)
- # Ensure that the hook modified missing_keys and unexpected_keys
- missing = ret.missing_keys
- unexpected = ret.unexpected_keys
- self.assertEqual(missing, ["foo"])
- self.assertEqual(unexpected, ["bar"])
- # When called with strict=True, the error raised should mention the
- # missing and unexpected keys the hook added.
- with self.assertRaisesRegex(RuntimeError, "foo.*\n.*bar"):
- wrapped.load_state_dict(wrapped.state_dict(), strict=True)
- self.assertEqual(hook_called, 2)
- # Removing the hook via handle.remove() should cause it not to
- # fire anymore.
- handle.remove()
- # Hook did not run so it should not have added any keys
- ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False)
- self.assertEqual(ret.missing_keys, [])
- self.assertEqual(ret.unexpected_keys, [])
- # hook_called should not have been incremented
- self.assertEqual(hook_called, 2)
-
- def load_hook_clear_incompatible(module, incompatible_keys):
- incompatible_keys.missing_keys.clear()
- incompatible_keys.unexpected_keys.clear()
-
- nested.register_load_state_dict_post_hook(load_hook_clear_incompatible)
- state_dict = wrapped.state_dict()
- state_dict["extra"] = torch.ones(1)
- # load state_dict with strict=True should not throw.
- ret = wrapped.load_state_dict(state_dict, strict=True)
- # explicitly ensure that the post hook clearned out incompatible_keys
- self.assertEqual([], ret.missing_keys)
- self.assertEqual([], ret.unexpected_keys)
-
instantiate_device_type_tests(TestNNDeviceType, globals())
instantiate_parametrized_tests(TestNN)
diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py
index 9316046..147f5ac 100644
--- a/torch/distributed/nn/api/remote_module.py
+++ b/torch/distributed/nn/api/remote_module.py
@@ -65,7 +65,6 @@
"_forward_pre_hooks",
"_state_dict_hooks",
"_load_state_dict_pre_hooks",
- "_load_state_dict_post_hooks",
"_modules",
# The two attributes below are generated methods, not available at pickling time.
"forward_async",
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 6291327..8dc32f5 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -282,7 +282,6 @@
self._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
self._state_dict_hooks: Dict[int, Callable] = OrderedDict()
self._load_state_dict_pre_hooks: Dict[int, Callable] = OrderedDict()
- self._load_state_dict_post_hooks: Dict[int, Callable] = OrderedDict()
self._modules: Dict[str, Optional['Module']] = OrderedDict()
forward: Callable[..., Any] = _forward_unimplemented
@@ -1198,8 +1197,6 @@
self._state_dict_hooks = OrderedDict()
if '_load_state_dict_pre_hooks' not in self.__dict__:
self._load_state_dict_pre_hooks = OrderedDict()
- if '_load_state_dict_post_hooks' not in self.__dict__:
- self._load_state_dict_post_hooks = OrderedDict()
if '_non_persistent_buffers_set' not in self.__dict__:
self._non_persistent_buffers_set = set()
if '_is_full_backward_hook' not in self.__dict__:
@@ -1420,37 +1417,6 @@
self._load_state_dict_pre_hooks[handle.id] = hook
return handle
- def register_load_state_dict_post_hook(self, hook):
- r"""Registers a post hook to be run after module's ``load_state_dict``
- is called.
-
- It should have the following signature::
- hook(module, incompatible_keys) -> None
-
- The ``module`` argument is the current module that this hook is registered
- on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting
- of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys``
- is a ``list`` of ``str`` containing the missing keys and
- ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.
-
- The given incompatible_keys can be modified inplace if needed.
-
- Note that the checks performed when calling :func:`load_state_dict` with
- ``strict=True`` are affected by modifications the hook makes to
- ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either
- set of keys will result in an error being thrown when ``strict=True``, and
- clearning out both missing and unexpected keys will avoid an error.
-
- Returns:
- :class:`torch.utils.hooks.RemovableHandle`:
- a handle that can be used to remove the added hook by calling
- ``handle.remove()``
- """
- handle = hooks.RemovableHandle(self._load_state_dict_post_hooks)
- self._load_state_dict_post_hooks[handle.id] = hook
- return handle
-
-
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
r"""Copies parameters and buffers from :attr:`state_dict` into only
@@ -1591,16 +1557,6 @@
if child is not None:
load(child, prefix + name + '.')
- # Note that the hook can modify missing_keys and unexpected_keys.
- incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
- for hook in module._load_state_dict_post_hooks.values():
- out = hook(module, incompatible_keys)
- assert out is None, (
- "Hooks registered with ``register_load_state_dict_post_hook`` are not"
- "expected to return new values, if incompatible_keys need to be modified,"
- "it should be done inplace."
- )
-
load(self)
del load