[FSDP] New fix for composing with other module wrappers (#87950)
We change `.module` to pass through any intermediate `ActivationWrapper` directly to the inner wrapped module. This should fix the state dict issues.
Given the invariant that `.module` always returns the inner wrapped module, FSDP always registers the `FlatParameter` on the inner wrapped module, regardless of whether there is an intermediate `ActivationWrapper`. This avoids having to branch on whether the `ActivationWrapper` is added before or after FSDP construction.
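For context, a minimal sketch of the invariant (not part of this PR). It assumes a process group has already been initialized and a GPU is available; `nn.Linear(5, 5)` is just an arbitrary example module:

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes `torch.distributed.init_process_group(...)` has already been called
# and a GPU is available.
inner = nn.Linear(5, 5).cuda()
fsdp = FSDP(checkpoint_wrapper(inner))  # FSDP(CheckpointWrapper(module))

# `.module` passes through the `CheckpointWrapper` to the innermost module, so
# the `FlatParameter` is registered on `inner` regardless of whether the
# checkpoint wrapper was applied before or after FSDP construction.
assert fsdp.module is inner
```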
This PR removes the added unit test in `test_fsdp_misc.py` for changing the wrapped module because I would rather not complicate the `_lazy_init()` logic just to support that kind of adversarial behavior. The user should not be swapping out the wrapped module arbitrarily or deleting the `FlatParameter`. I mainly had that test to make sure that all branches of the code I added were correct.
Differential Revision: [D40799961](https://our.internmc.facebook.com/intern/diff/D40799961)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87950
Approved by: https://github.com/zhaojuanmao
diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py
index b57f5a1..98cd648 100644
--- a/test/distributed/fsdp/test_fsdp_misc.py
+++ b/test/distributed/fsdp/test_fsdp_misc.py
@@ -9,22 +9,12 @@
import torch
import torch.distributed as dist
import torch.nn as nn
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
- _CHECKPOINT_PREFIX,
- apply_activation_checkpointing,
- checkpoint_wrapper,
- CheckpointImpl,
-)
from torch.distributed.fsdp import (
CPUOffload,
FlatParameter,
FullyShardedDataParallel as FSDP,
ShardingStrategy,
)
-from torch.distributed.fsdp.fully_sharded_data_parallel import (
- FLAT_PARAM,
- FSDP_WRAPPED_MODULE,
-)
from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
@@ -499,90 +489,6 @@
fsdp, process_group=self.process_group, assert_fn=self.assertEqual
)
- @skip_if_lt_x_gpu(2)
- def test_change_wrapped_module_after_ctor(self):
- """
- Tests changing an FSDP instance's wrapped module after the FSDP
- constructor.
- """
- dist.set_debug_level(dist.DebugLevel.DETAIL)
-
- class Model(nn.Module):
- def __init__(self) -> None:
- super().__init__()
- self.seq1 = nn.Sequential(
- nn.Linear(5, 5),
- nn.Linear(5, 5),
- )
- self.seq2 = nn.Sequential(nn.Linear(5, 5))
- self.lin = nn.Linear(5, 5)
- self.relu = nn.ReLU()
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return self.lin(self.relu(self.seq2(self.relu(self.seq1(x)))))
-
- def get_fsdp_model():
- fsdp_kwargs = {"use_orig_params": False}
- model = Model().cuda()
- model.seq1 = FSDP(model.seq1, **fsdp_kwargs)
- model.seq2[0] = FSDP(model.seq2[0], **fsdp_kwargs)
- model = FSDP(model, **fsdp_kwargs)
- return model
-
- # Wrap with `CheckpointWrapper` *after* FSDP construction
- model = get_fsdp_model()
- non_reentrant_wrapper = functools.partial(
- checkpoint_wrapper,
- offload_to_cpu=False,
- checkpoint_impl=CheckpointImpl.NO_REENTRANT,
- )
- apply_activation_checkpointing(
- model,
- checkpoint_wrapper_fn=non_reentrant_wrapper,
- check_fn=lambda submodule: isinstance(submodule, nn.Linear),
- )
-
- # Check that `seq2[0]` only has a single `FlatParameter` registered and
- # that it has the `CheckpointWrapper` prefix in its FQN since it was
- # registered to the `Linear` wrapped module in the FSDP constructor and
- # only wrapped with `CheckpointWrapper` after
- seq2_0_named_params = list(model.seq2[0].named_parameters())
- self.assertEqual(len(seq2_0_named_params), 1)
- self.assertTrue(type(seq2_0_named_params[0][1]) is FlatParameter)
- self.assertTrue(_CHECKPOINT_PREFIX in seq2_0_named_params[0][0])
-
- # Trigger the re-registration via `_lazy_init()`, and check for a
- # warning, which is only emitted for DETAIL
- with self.assertWarnsRegex(
- UserWarning,
- "The FSDP wrapped module changed from Linear.*to CheckpointWrapper",
- ):
- model._lazy_init()
-
- # Check that now the `FlatParameter` is registered to the
- # `CheckpointWrapper`, which is now the new wrapped module
- seq2_0_named_params = list(model.seq2[0].named_parameters())
- self.assertEqual(len(seq2_0_named_params), 1)
- self.assertTrue(type(seq2_0_named_params[0][1]) is FlatParameter)
- self.assertFalse(_CHECKPOINT_PREFIX in seq2_0_named_params[0][0])
- self.assertFalse(isinstance(model.seq2[0].module, nn.Linear))
-
- # Check that replacing a module *after* FSDP construction errors
- model = get_fsdp_model()
- # NOTE: Setting `model.seq2[0].module = nn.Linear(3, 3)` does not save
- # to the FSDP instance's `module` attribute since `module` is a
- # property, meaning that it would not actually change the wrapped
- # module, so we use `setattr()` like in `_recursive_wrap()`.
- setattr(model.seq2[0], FSDP_WRAPPED_MODULE, nn.Linear(3, 3))
- with self.assertRaisesRegex(RuntimeError, "are invalid behavior"):
- model._lazy_init()
-
- # Check that deleting the `FlatParameter` errors
- model = get_fsdp_model()
- delattr(model.seq2[0].module, FLAT_PARAM)
- with self.assertRaisesRegex(RuntimeError, "are invalid behavior"):
- model._lazy_init()
-
instantiate_parametrized_tests(TestFSDPMisc)
diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py
index b8cbae5..1334050 100644
--- a/test/distributed/fsdp/test_fsdp_state_dict.py
+++ b/test/distributed/fsdp/test_fsdp_state_dict.py
@@ -285,6 +285,69 @@
self._compare_models(model, model_new, self.assertEqual)
@skip_if_lt_x_gpu(2)
+ @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
+ @parametrize("rank0_only_and_offload", [False, True])
+ def test_state_dict_with_manual_ac_wrapper(
+ self,
+ state_dict_type: str,
+ rank0_only_and_offload: bool,
+ ):
+ """
+ Tests saving and loading a state dict for a model manually wrapped with
+ ``FSDP(CheckpointWrapper(module))``, where the ``CheckpointWrapper`` is
+ wrapped before FSDP.
+
+ TODO: Investigate why the test above does not cover everything in this
+ test and de-duplicate afterwards.
+ """
+ if state_dict_type == "sharded_state_dict" and rank0_only_and_offload:
+ return # not supported
+ model_ac = TransformerWithSharedParams.init(
+ self.process_group,
+ FSDPInitMode.NO_FSDP,
+ CUDAInitMode.CUDA_BEFORE,
+ )
+ # Manually wrap FSDP without AC
+ model_no_ac = deepcopy(model_ac)
+ for i, layer in enumerate(model_no_ac.transformer.encoder.layers):
+ model_no_ac.transformer.encoder.layers[i] = FSDP(layer)
+ for i, layer in enumerate(model_no_ac.transformer.decoder.layers):
+ model_no_ac.transformer.decoder.layers[i] = FSDP(layer)
+ model_no_ac.transformer = FSDP(model_no_ac.transformer)
+
+ # Manually wrap FSDP with AC as `FSDP(CheckpointWrapper(module))`
+ for i, layer in enumerate(model_ac.transformer.encoder.layers):
+ layer = checkpoint_wrapper(layer)
+ model_ac.transformer.encoder.layers[i] = FSDP(layer)
+ for i, layer in enumerate(model_ac.transformer.decoder.layers):
+ layer = checkpoint_wrapper(layer)
+ model_ac.transformer.decoder.layers[i] = FSDP(layer)
+ model_ac.transformer = FSDP(model_ac.transformer)
+
+ # Save, load, and compare the two models
+ with self._get_state_dict_mgr(
+ model_no_ac, state_dict_type, rank0_only_and_offload
+ ):
+ state_dict_no_ac = model_no_ac.state_dict()
+ with self._get_state_dict_mgr(
+ model_ac, state_dict_type, rank0_only_and_offload
+ ):
+ state_dict_ac = model_ac.state_dict()
+ self.assertEqual(state_dict_ac.keys(), state_dict_no_ac.keys())
+ if rank0_only_and_offload:
+ state_dict_no_ac = self._broadcast_state_dict(model_no_ac, state_dict_no_ac)
+ state_dict_ac = self._broadcast_state_dict(model_ac, state_dict_ac)
+ with self._get_state_dict_mgr(
+ model_no_ac, state_dict_type, rank0_only_and_offload
+ ):
+ model_no_ac.load_state_dict(state_dict_no_ac)
+ with self._get_state_dict_mgr(
+ model_ac, state_dict_type, rank0_only_and_offload
+ ):
+ model_ac.load_state_dict(state_dict_ac)
+ self._compare_models(model_ac, model_no_ac, self.assertEqual)
+
+ @skip_if_lt_x_gpu(2)
@parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
def test_state_dict_with_shared_parameters(self, state_dict_type):
auto_wrap_policy = partial(
diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
index 9e72fb6..35f8acf 100644
--- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
+++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
@@ -8,6 +8,8 @@
from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs
from torch.utils.checkpoint import checkpoint as torch_utils_checkpoint
+# TODO: Refactor `_CHECKPOINT_PREFIX` to include the trailing '.' like FSDP
+_CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module"
_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module"
class CheckpointImpl(Enum):
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 21b7787..0fd6019 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -34,6 +34,8 @@
from torch.distributed import ProcessGroup
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
_CHECKPOINT_PREFIX,
+ _CHECKPOINT_WRAPPED_MODULE,
+ ActivationWrapper,
)
from torch.distributed.algorithms._comm_hooks import default_hooks, LOW_PRECISION_HOOKS
from torch.distributed.distributed_c10d import _get_default_group
@@ -1634,6 +1636,10 @@
"""
Returns the wrapped module (like :class:`DistributedDataParallel`).
"""
+ # FSDP's `.module` must refer to the innermost wrapped module when
+ # composing with other module wrappers in order for state dict to work
+ if isinstance(self._fsdp_wrapped_module, ActivationWrapper):
+ return getattr(self._fsdp_wrapped_module, _CHECKPOINT_WRAPPED_MODULE)
return self._fsdp_wrapped_module
@property
@@ -1884,44 +1890,6 @@
# to non-root instances
inconsistent_limit_all_gathers = False
for fsdp_module in self.fsdp_modules(self):
- if not fsdp_module._use_orig_params and fsdp_module._has_params:
- # Check if the wrapped module changed after construction
- # (e.g. applying the activation checkpointing wrapper) and
- # if so, de-register the `FlatParameter` from the old
- # wrapped module and register it to the new wrapped module
- # NOTE: The `FlatParameter`'s FQN metadata is not updated, so
- # any added wrappers must clean their prefixes from FQNs.
- flat_param = fsdp_module._handles[0].flat_param
- target_submodule = None
- target_name = None
- for submodule in fsdp_module.modules():
- for param_name, param in submodule._parameters.items():
- if flat_param is param: # found registered `FlatParameter`
- target_submodule = submodule
- target_name = param_name
- break
- if target_submodule is not None:
- break
- if (
- target_submodule is not None
- and target_submodule is not fsdp_module.module
- ):
- assert target_name is not None
- if fsdp_module._debug_level == dist.DebugLevel.DETAIL:
- warnings.warn(
- "The FSDP wrapped module changed from "
- f"{target_submodule} to {fsdp_module.module} on "
- f"rank {fsdp_module.rank}. {fsdp_module}"
- )
- target_submodule._parameters.pop(target_name) # de-register
- fsdp_module._register_flat_param() # re-register
- elif target_submodule is None:
- raise RuntimeError(
- "Either the FSDP wrapped module was removed from "
- "the model or its `FlatParameter` was manually "
- f"de-registered on rank {fsdp_module.rank}. Both of "
- f"these are invalid behavior. {fsdp_module}"
- )
if fsdp_module is not self:
# Relax the assert for non-root FSDP instances in case the
# nested initialized module is wrapped again in FSDP later (e.g.