[FSDP2] reset FSDPParam.sharded_param in lazy_init (#132954)
Motivated by FSDP2 + DoRA: https://github.com/pytorch/pytorch/issues/132721
After meta-device init, a user-defined function has to move DoRALinear.magnitude from device=meta to device=cuda (e.g. by recomputing it from the loaded weights).
The problem is how to trigger reset_sharded_param or _apply so that the corresponding FSDPParam is updated; otherwise lazy_init complains that DoRALinear.magnitude is still on device=meta. This PR resets FSDPParam.sharded_param during lazy init so that such post-construction parameter updates are picked up.
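For context, a minimal sketch of the post-meta-init flow that needs this reset. The DoRALinear module, its reset_magnitude helper, and the shapes are illustrative only (not the actual DoRA implementation), and the sketch assumes a default process group and CUDA device are already set up (e.g. via torchrun):

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

# Illustrative DoRA-style module: a weight plus a per-output-row magnitude.
class DoRALinear(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.magnitude = nn.Parameter(torch.empty(out_features))

    def reset_magnitude(self):
        # User-defined post-init step: recompute the magnitude from the weight
        # and register it as a fresh (non-meta) parameter after sharding.
        with torch.no_grad():
            magnitude = torch.linalg.norm(self.weight, dim=1)
        self.magnitude = nn.Parameter(magnitude)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x @ self.weight.t()) * self.magnitude

# Construct on meta, shard, then materialize and re-init the magnitude.
with torch.device("meta"):
    model = DoRALinear(16, 16)
fully_shard(model)

model.to_empty(device="cuda")
model.reset_magnitude()

# lazy_init runs on the first forward; without resetting FSDPParam.sharded_param
# there, it would still see the stale meta-device magnitude and error out.
out = model(torch.randn(4, 16, device="cuda"))
out.sum().backward()
```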
Credit to @awgu for chasing down DDP lazy_init to unblock this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132954
Approved by: https://github.com/awgu
ghstack dependencies: #133059
diff --git a/test/distributed/_composable/fsdp/test_fully_shard_init.py b/test/distributed/_composable/fsdp/test_fully_shard_init.py
index 8c02196..a2bf15d 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_init.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_init.py
@@ -536,6 +536,41 @@
         with self.assertRaisesRegex(RuntimeError, regex):
             root_state._lazy_init()
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_reset_sharded_param_in_lazy_init(self):
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer1 = nn.Linear(3, 3, bias=False)
+                self.layer2 = nn.Linear(3, 3, bias=False)
+                self.weight_norm = nn.Parameter(torch.empty(3))
+
+            def init_weight_norm(self):
+                with torch.no_grad():
+                    weight_norm = torch.linalg.norm(
+                        self.layer1.weight, dim=1
+                    ) + torch.linalg.norm(self.layer2.weight, dim=1)
+                model.weight_norm = nn.Parameter(weight_norm)
+
+            def forward(self, inp: torch.Tensor) -> torch.Tensor:
+                out = self.layer1(inp)
+                out = self.layer2(out)
+                return out.sum() + self.weight_norm.sum()
+
+        with torch.device("meta"):
+            model = MyModel()
+        fully_shard(model.layer1)
+        fully_shard(model.layer2)
+        fully_shard(model)
+
+        model.layer1.to_empty(device="cuda")
+        model.layer2.to_empty(device="cuda")
+        model.init_weight_norm()
+
+        inp = torch.randn(3, 3, device="cuda")
+        loss = model(inp).sum()
+        loss.backward()
+
 class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
     @property
diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py
index dbbcf6e..3c2f639 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py
@@ -131,6 +131,9 @@
         # Group's sharded state always matches its parameters' sharded states
         self._sharded_state = ShardedState.SHARDED
         self._module_fqn: Optional[str] = None  # prefixed from root module
+        # Only consider resetting sharded parameters once in lazy init since it
+        # can incur nontrivial overhead to reset them
+        self._reset_sharded_params: bool = False
         # - Hook state
         self._module_to_pre_save_state_dict_hook_handle: _ModuleToHandleDict = {}
@@ -191,6 +194,13 @@
     def lazy_init(self):
         # Lazy init should be idempotent
+        # Users may change or register parameters after construction time.
+        # For example, DoRA (https://arxiv.org/abs/2402.09353) initializes linear magnitudes based on
+        # other parameters (e.g. loaded from the state dict).
+        if self.is_sharded and not self._reset_sharded_params:
+            for fsdp_param in self.fsdp_params:
+                fsdp_param.reset_sharded_param()
+            self._reset_sharded_params = True
         param_names_on_meta = [
             fsdp_param._param_fqn
             for fsdp_param in self.fsdp_params