Revert "Load state dict post hook"

This reverts commit 56bed0dcfe7ca9047e5c95a6f3d7fcb0ec403b0c.

Reverted https://github.com/pytorch/pytorch/pull/76823 on behalf of https://github.com/rohan-varma
diff --git a/test/test_nn.py b/test/test_nn.py
index 9dfb307..91c5b70 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -21557,62 +21557,6 @@
             m.load_state_dict(state_dict)
             self.assertEqual(2, hook_called)
 
-    def test_load_state_dict_post_hook(self):
-        hook_called = 0
-
-        class MyModule(nn.Module):
-            def __init__(self):
-                super(MyModule, self).__init__()
-                self.foo = torch.nn.Parameter(torch.rand(10))
-
-            def my_post_load_hook(self, module, incompatible_keys):
-                assert module is self
-                nonlocal hook_called
-                incompatible_keys.missing_keys.append("foo")
-                incompatible_keys.unexpected_keys.append("bar")
-                hook_called += 1
-
-        nested = MyModule()
-        wrapped = nn.ModuleList([nested])
-        handle = nested.register_load_state_dict_post_hook(
-            nested.my_post_load_hook,
-        )
-        # Hook must be called even if it is wrapped
-        ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False)
-        self.assertEqual(hook_called, 1)
-        # Ensure that the hook modified missing_keys and unexpected_keys
-        missing = ret.missing_keys
-        unexpected = ret.unexpected_keys
-        self.assertEqual(missing, ["foo"])
-        self.assertEqual(unexpected, ["bar"])
-        # When called with strict=True, the error raised should mention the
-        # missing and unexpected keys the hook added.
-        with self.assertRaisesRegex(RuntimeError, "foo.*\n.*bar"):
-            wrapped.load_state_dict(wrapped.state_dict(), strict=True)
-        self.assertEqual(hook_called, 2)
-        # Removing the hook via handle.remove() should cause it not to
-        # fire anymore.
-        handle.remove()
-        # Hook did not run so it should not have added any keys
-        ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False)
-        self.assertEqual(ret.missing_keys, [])
-        self.assertEqual(ret.unexpected_keys, [])
-        # hook_called should not have been incremented
-        self.assertEqual(hook_called, 2)
-
-        def load_hook_clear_incompatible(module, incompatible_keys):
-            incompatible_keys.missing_keys.clear()
-            incompatible_keys.unexpected_keys.clear()
-
-        nested.register_load_state_dict_post_hook(load_hook_clear_incompatible)
-        state_dict = wrapped.state_dict()
-        state_dict["extra"] = torch.ones(1)
-        # load state_dict with strict=True should not throw.
-        ret = wrapped.load_state_dict(state_dict, strict=True)
-        # explicitly ensure that the post hook clearned out incompatible_keys
-        self.assertEqual([], ret.missing_keys)
-        self.assertEqual([], ret.unexpected_keys)
-
 
 instantiate_device_type_tests(TestNNDeviceType, globals())
 instantiate_parametrized_tests(TestNN)
diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py
index 9316046..147f5ac 100644
--- a/torch/distributed/nn/api/remote_module.py
+++ b/torch/distributed/nn/api/remote_module.py
@@ -65,7 +65,6 @@
     "_forward_pre_hooks",
     "_state_dict_hooks",
     "_load_state_dict_pre_hooks",
-    "_load_state_dict_post_hooks",
     "_modules",
     # The two attributes below are generated methods, not available at pickling time.
     "forward_async",
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 6291327..8dc32f5 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -282,7 +282,6 @@
         self._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
         self._state_dict_hooks: Dict[int, Callable] = OrderedDict()
         self._load_state_dict_pre_hooks: Dict[int, Callable] = OrderedDict()
-        self._load_state_dict_post_hooks: Dict[int, Callable] = OrderedDict()
         self._modules: Dict[str, Optional['Module']] = OrderedDict()
 
     forward: Callable[..., Any] = _forward_unimplemented
@@ -1198,8 +1197,6 @@
             self._state_dict_hooks = OrderedDict()
         if '_load_state_dict_pre_hooks' not in self.__dict__:
             self._load_state_dict_pre_hooks = OrderedDict()
-        if '_load_state_dict_post_hooks' not in self.__dict__:
-            self._load_state_dict_post_hooks = OrderedDict()
         if '_non_persistent_buffers_set' not in self.__dict__:
             self._non_persistent_buffers_set = set()
         if '_is_full_backward_hook' not in self.__dict__:
@@ -1420,37 +1417,6 @@
         self._load_state_dict_pre_hooks[handle.id] = hook
         return handle
 
-    def register_load_state_dict_post_hook(self, hook):
-        r"""Registers a post hook to be run after module's ``load_state_dict``
-        is called.
-
-        It should have the following signature::
-            hook(module, incompatible_keys) -> None
-
-        The ``module`` argument is the current module that this hook is registered
-        on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting
-        of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys``
-        is a ``list`` of ``str`` containing the missing keys and
-        ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.
-
-        The given incompatible_keys can be modified inplace if needed.
-
-        Note that the checks performed when calling :func:`load_state_dict` with
-        ``strict=True`` are affected by modifications the hook makes to
-        ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either
-        set of keys will result in an error being thrown when ``strict=True``, and
-        clearning out both missing and unexpected keys will avoid an error.
-
-        Returns:
-            :class:`torch.utils.hooks.RemovableHandle`:
-                a handle that can be used to remove the added hook by calling
-                ``handle.remove()``
-        """
-        handle = hooks.RemovableHandle(self._load_state_dict_post_hooks)
-        self._load_state_dict_post_hooks[handle.id] = hook
-        return handle
-
-
     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):
         r"""Copies parameters and buffers from :attr:`state_dict` into only
@@ -1591,16 +1557,6 @@
                 if child is not None:
                     load(child, prefix + name + '.')
 
-            # Note that the hook can modify missing_keys and unexpected_keys.
-            incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
-            for hook in module._load_state_dict_post_hooks.values():
-                out = hook(module, incompatible_keys)
-                assert out is None, (
-                    "Hooks registered with ``register_load_state_dict_post_hook`` are not"
-                    "expected to return new values, if incompatible_keys need to be modified,"
-                    "it should be done inplace."
-                )
-
         load(self)
         del load