[FSDP] New fix for composing with other module wrappers (#87950)

We change `.module` to pass through `ActivationWrapper` directly to the inner wrapped module. This should fix the state dict issues.

Given the invariant that `.module` always returns the inner wrapped module, FSDP always registers the `FlatParameter` on the inner wrapped module, regardless of whether there is an intermediate `ActivationWrapper`. This avoids special-casing on whether the `ActivationWrapper` is added before or after FSDP construction.
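
For reference, a minimal sketch (not part of this PR) of the pass-through behavior, assuming a single-process NCCL group and one available GPU:

```python
# Sketch only: illustrates that `.module` skips the ActivationWrapper and
# returns the innermost wrapped module. Assumes one GPU and a 1-rank group.
import os

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)

lin = nn.Linear(5, 5).cuda()
# Apply activation checkpointing *before* FSDP: FSDP(CheckpointWrapper(lin))
fsdp_lin = FSDP(checkpoint_wrapper(lin))
# `.module` passes through the intermediate `ActivationWrapper`, so it returns
# the innermost `Linear`, which is where the `FlatParameter` is registered.
assert fsdp_lin.module is lin

dist.destroy_process_group()
```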

This PR removes the previously added unit test in `test_fsdp_misc.py` for changing the wrapped module because I would rather not complicate the `_lazy_init()` logic just to support that kind of adversarial behavior. The user should not be swapping out the wrapped module arbitrarily or deleting the `FlatParameter`. I mainly had those tests to make sure that all branches of the code I added were correct.

Differential Revision: [D40799961](https://our.internmc.facebook.com/intern/diff/D40799961)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87950
Approved by: https://github.com/zhaojuanmao
diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py
index b57f5a1..98cd648 100644
--- a/test/distributed/fsdp/test_fsdp_misc.py
+++ b/test/distributed/fsdp/test_fsdp_misc.py
@@ -9,22 +9,12 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    _CHECKPOINT_PREFIX,
-    apply_activation_checkpointing,
-    checkpoint_wrapper,
-    CheckpointImpl,
-)
 from torch.distributed.fsdp import (
     CPUOffload,
     FlatParameter,
     FullyShardedDataParallel as FSDP,
     ShardingStrategy,
 )
-from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    FLAT_PARAM,
-    FSDP_WRAPPED_MODULE,
-)
 from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy
 from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
@@ -499,90 +489,6 @@
                 fsdp, process_group=self.process_group, assert_fn=self.assertEqual
             )
 
-    @skip_if_lt_x_gpu(2)
-    def test_change_wrapped_module_after_ctor(self):
-        """
-        Tests changing an FSDP instance's wrapped module after the FSDP
-        constructor.
-        """
-        dist.set_debug_level(dist.DebugLevel.DETAIL)
-
-        class Model(nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-                self.seq1 = nn.Sequential(
-                    nn.Linear(5, 5),
-                    nn.Linear(5, 5),
-                )
-                self.seq2 = nn.Sequential(nn.Linear(5, 5))
-                self.lin = nn.Linear(5, 5)
-                self.relu = nn.ReLU()
-
-            def forward(self, x: torch.Tensor) -> torch.Tensor:
-                return self.lin(self.relu(self.seq2(self.relu(self.seq1(x)))))
-
-        def get_fsdp_model():
-            fsdp_kwargs = {"use_orig_params": False}
-            model = Model().cuda()
-            model.seq1 = FSDP(model.seq1, **fsdp_kwargs)
-            model.seq2[0] = FSDP(model.seq2[0], **fsdp_kwargs)
-            model = FSDP(model, **fsdp_kwargs)
-            return model
-
-        # Wrap with `CheckpointWrapper` *after* FSDP construction
-        model = get_fsdp_model()
-        non_reentrant_wrapper = functools.partial(
-            checkpoint_wrapper,
-            offload_to_cpu=False,
-            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
-        )
-        apply_activation_checkpointing(
-            model,
-            checkpoint_wrapper_fn=non_reentrant_wrapper,
-            check_fn=lambda submodule: isinstance(submodule, nn.Linear),
-        )
-
-        # Check that `seq2[0]` only has a single `FlatParameter` registered and
-        # that it has the `CheckpointWrapper` prefix in its FQN since it was
-        # registered to the `Linear` wrapped module in the FSDP constructor and
-        # only wrapped with `CheckpointWrapper` after
-        seq2_0_named_params = list(model.seq2[0].named_parameters())
-        self.assertEqual(len(seq2_0_named_params), 1)
-        self.assertTrue(type(seq2_0_named_params[0][1]) is FlatParameter)
-        self.assertTrue(_CHECKPOINT_PREFIX in seq2_0_named_params[0][0])
-
-        # Trigger the re-registration via `_lazy_init()`, and check for a
-        # warning, which is only emitted for DETAIL
-        with self.assertWarnsRegex(
-            UserWarning,
-            "The FSDP wrapped module changed from Linear.*to CheckpointWrapper",
-        ):
-            model._lazy_init()
-
-        # Check that now the `FlatParameter` is registered to the
-        # `CheckpointWrapper`, which is now the new wrapped module
-        seq2_0_named_params = list(model.seq2[0].named_parameters())
-        self.assertEqual(len(seq2_0_named_params), 1)
-        self.assertTrue(type(seq2_0_named_params[0][1]) is FlatParameter)
-        self.assertFalse(_CHECKPOINT_PREFIX in seq2_0_named_params[0][0])
-        self.assertFalse(isinstance(model.seq2[0].module, nn.Linear))
-
-        # Check that replacing a module *after* FSDP construction errors
-        model = get_fsdp_model()
-        # NOTE: Setting `model.seq2[0].module = nn.Linear(3, 3)` does not save
-        # to the FSDP instance's `module` attribute since `module` is a
-        # property, meaning that it would not actually change the wrapped
-        # module, so we use `setattr()` like in `_recursive_wrap()`.
-        setattr(model.seq2[0], FSDP_WRAPPED_MODULE, nn.Linear(3, 3))
-        with self.assertRaisesRegex(RuntimeError, "are invalid behavior"):
-            model._lazy_init()
-
-        # Check that deleting the `FlatParameter` errors
-        model = get_fsdp_model()
-        delattr(model.seq2[0].module, FLAT_PARAM)
-        with self.assertRaisesRegex(RuntimeError, "are invalid behavior"):
-            model._lazy_init()
-
 
 instantiate_parametrized_tests(TestFSDPMisc)
 
diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py
index b8cbae5..1334050 100644
--- a/test/distributed/fsdp/test_fsdp_state_dict.py
+++ b/test/distributed/fsdp/test_fsdp_state_dict.py
@@ -285,6 +285,69 @@
                 self._compare_models(model, model_new, self.assertEqual)
 
     @skip_if_lt_x_gpu(2)
+    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
+    @parametrize("rank0_only_and_offload", [False, True])
+    def test_state_dict_with_manual_ac_wrapper(
+        self,
+        state_dict_type: str,
+        rank0_only_and_offload: bool,
+    ):
+        """
+        Tests saving and loading a state dict for a model manually wrapped with
+        ``FSDP(CheckpointWrapper(module))``, where the ``CheckpointWrapper`` is
+        wrapped before FSDP.
+
+        TODO: Investigate why the test above does not cover everything in this
+        test and de-duplicate afterwards.
+        """
+        if state_dict_type == "sharded_state_dict" and rank0_only_and_offload:
+            return  # not supported
+        model_ac = TransformerWithSharedParams.init(
+            self.process_group,
+            FSDPInitMode.NO_FSDP,
+            CUDAInitMode.CUDA_BEFORE,
+        )
+        # Manually wrap FSDP without AC
+        model_no_ac = deepcopy(model_ac)
+        for i, layer in enumerate(model_no_ac.transformer.encoder.layers):
+            model_no_ac.transformer.encoder.layers[i] = FSDP(layer)
+        for i, layer in enumerate(model_no_ac.transformer.decoder.layers):
+            model_no_ac.transformer.decoder.layers[i] = FSDP(layer)
+        model_no_ac.transformer = FSDP(model_no_ac.transformer)
+
+        # Manually wrap FSDP with AC as `FSDP(CheckpointWrapper(module))`
+        for i, layer in enumerate(model_ac.transformer.encoder.layers):
+            layer = checkpoint_wrapper(layer)
+            model_ac.transformer.encoder.layers[i] = FSDP(layer)
+        for i, layer in enumerate(model_ac.transformer.decoder.layers):
+            layer = checkpoint_wrapper(layer)
+            model_ac.transformer.decoder.layers[i] = FSDP(layer)
+        model_ac.transformer = FSDP(model_ac.transformer)
+
+        # Save, load, and compare the two models
+        with self._get_state_dict_mgr(
+            model_no_ac, state_dict_type, rank0_only_and_offload
+        ):
+            state_dict_no_ac = model_no_ac.state_dict()
+        with self._get_state_dict_mgr(
+            model_ac, state_dict_type, rank0_only_and_offload
+        ):
+            state_dict_ac = model_ac.state_dict()
+        self.assertEqual(state_dict_ac.keys(), state_dict_no_ac.keys())
+        if rank0_only_and_offload:
+            state_dict_no_ac = self._broadcast_state_dict(model_no_ac, state_dict_no_ac)
+            state_dict_ac = self._broadcast_state_dict(model_ac, state_dict_ac)
+        with self._get_state_dict_mgr(
+            model_no_ac, state_dict_type, rank0_only_and_offload
+        ):
+            model_no_ac.load_state_dict(state_dict_no_ac)
+        with self._get_state_dict_mgr(
+            model_ac, state_dict_type, rank0_only_and_offload
+        ):
+            model_ac.load_state_dict(state_dict_ac)
+        self._compare_models(model_ac, model_no_ac, self.assertEqual)
+
+    @skip_if_lt_x_gpu(2)
     @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
     def test_state_dict_with_shared_parameters(self, state_dict_type):
         auto_wrap_policy = partial(
diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
index 9e72fb6..35f8acf 100644
--- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
+++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
@@ -8,6 +8,8 @@
 from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs
 from torch.utils.checkpoint import checkpoint as torch_utils_checkpoint
 
+# TODO: Refactor `_CHECKPOINT_PREFIX` to include the trailing '.' like FSDP
+_CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module"
 _CHECKPOINT_PREFIX = "_checkpoint_wrapped_module"
 
 class CheckpointImpl(Enum):
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 21b7787..0fd6019 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -34,6 +34,8 @@
 from torch.distributed import ProcessGroup
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     _CHECKPOINT_PREFIX,
+    _CHECKPOINT_WRAPPED_MODULE,
+    ActivationWrapper,
 )
 from torch.distributed.algorithms._comm_hooks import default_hooks, LOW_PRECISION_HOOKS
 from torch.distributed.distributed_c10d import _get_default_group
@@ -1634,6 +1636,10 @@
         """
         Returns the wrapped module (like :class:`DistributedDataParallel`).
         """
+        # FSDP's `.module` must refer to the innermost wrapped module when
+        # composing with other module wrappers in order for state dict to work
+        if isinstance(self._fsdp_wrapped_module, ActivationWrapper):
+            return getattr(self._fsdp_wrapped_module, _CHECKPOINT_WRAPPED_MODULE)
         return self._fsdp_wrapped_module
 
     @property
@@ -1884,44 +1890,6 @@
         # to non-root instances
         inconsistent_limit_all_gathers = False
         for fsdp_module in self.fsdp_modules(self):
-            if not fsdp_module._use_orig_params and fsdp_module._has_params:
-                # Check if the wrapped module changed after construction
-                # (e.g. applying the activation checkpointing wrapper) and
-                # if so, de-register the `FlatParameter` from the old
-                # wrapped module and register it to the new wrapped module
-                # NOTE: The `FlatParameter`'s FQN metadata is not updated, so
-                # any added wrappers must clean their prefixes from FQNs.
-                flat_param = fsdp_module._handles[0].flat_param
-                target_submodule = None
-                target_name = None
-                for submodule in fsdp_module.modules():
-                    for param_name, param in submodule._parameters.items():
-                        if flat_param is param:  # found registered `FlatParameter`
-                            target_submodule = submodule
-                            target_name = param_name
-                            break
-                    if target_submodule is not None:
-                        break
-                if (
-                    target_submodule is not None
-                    and target_submodule is not fsdp_module.module
-                ):
-                    assert target_name is not None
-                    if fsdp_module._debug_level == dist.DebugLevel.DETAIL:
-                        warnings.warn(
-                            "The FSDP wrapped module changed from "
-                            f"{target_submodule} to {fsdp_module.module} on "
-                            f"rank {fsdp_module.rank}. {fsdp_module}"
-                        )
-                    target_submodule._parameters.pop(target_name)  # de-register
-                    fsdp_module._register_flat_param()  # re-register
-                elif target_submodule is None:
-                    raise RuntimeError(
-                        "Either the FSDP wrapped module was removed from "
-                        "the model or its `FlatParameter` was manually "
-                        f"de-registered on rank {fsdp_module.rank}. Both of "
-                        f"these are invalid behavior. {fsdp_module}"
-                    )
             if fsdp_module is not self:
                 # Relax the assert for non-root FSDP instances in case the
                 # nested initialized module is wrapped again in FSDP later (e.g.