[FSDP2] reset FSDPParam.sharded_param in lazy_init (#132954)

Motivated by FSDP2 + DoRA: https://github.com/pytorch/pytorch/issues/132721

After meta-device initialization, a user-defined function is needed to move DoRALinear.magnitude from device=meta to device=cuda.
The problem is how to trigger reset_sharded_param or _apply so that the corresponding FSDPParam is updated; otherwise, lazy_init complains that DoRALinear.magnitude is still on device=meta. This PR resets FSDPParam.sharded_param once during lazy_init, so parameters that are replaced or registered after fully_shard are picked up before the first forward.
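
For context, the user flow this change unblocks looks roughly like the sketch below (not part of this PR; the DoRAStyleLinear module and its init_magnitude helper are illustrative stand-ins for DoRA, and a CUDA device plus a torchrun-initialized process group are assumed):

```python
# Minimal sketch of the meta-init flow that previously failed: a DoRA-style
# magnitude parameter is re-created after fully_shard, and lazy_init now picks
# up the new parameter object by resetting FSDPParam.sharded_param.
# Names like DoRAStyleLinear and init_magnitude are illustrative only.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard


class DoRAStyleLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)
        # Placeholder on meta; the real values depend on the weight
        self.magnitude = nn.Parameter(torch.empty(out_features))

    def init_magnitude(self):
        # User-defined init that runs after to_empty(): it replaces the
        # parameter object, which is what lazy_init must account for
        with torch.no_grad():
            self.magnitude = nn.Parameter(
                torch.linalg.norm(self.linear.weight, dim=1)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) * self.magnitude


# Assumes launch via torchrun so that rank/world-size env vars are set
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

with torch.device("meta"):
    model = DoRAStyleLinear(16, 16)
fully_shard(model)

# Materialize parameters (in practice the linear weight would be loaded or
# initialized here, e.g. from a state dict) and re-create magnitude on CUDA
model.to_empty(device="cuda")
model.init_magnitude()

# The first forward triggers lazy_init; before this PR it raised because the
# FSDPParam for `magnitude` still referenced the meta-device tensor
loss = model(torch.randn(4, 16, device="cuda")).sum()
loss.backward()
dist.destroy_process_group()
```

The new unit test below exercises the same pattern: weight_norm is re-created after fully_shard, and the first forward/backward now succeeds because lazy_init resets the sharded parameters once.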

Credit to @awgu for chasing down DDP lazy_init to unblock this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132954
Approved by: https://github.com/awgu
ghstack dependencies: #133059
diff --git a/test/distributed/_composable/fsdp/test_fully_shard_init.py b/test/distributed/_composable/fsdp/test_fully_shard_init.py
index 8c02196..a2bf15d 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_init.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_init.py
@@ -536,6 +536,41 @@
         with self.assertRaisesRegex(RuntimeError, regex):
             root_state._lazy_init()
 
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_reset_sharded_param_in_lazy_init(self):
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer1 = nn.Linear(3, 3, bias=False)
+                self.layer2 = nn.Linear(3, 3, bias=False)
+                self.weight_norm = nn.Parameter(torch.empty(3))
+
+            def init_weight_norm(self):
+                with torch.no_grad():
+                    weight_norm = torch.linalg.norm(
+                        self.layer1.weight, dim=1
+                    ) + torch.linalg.norm(self.layer2.weight, dim=1)
+                self.weight_norm = nn.Parameter(weight_norm)
+
+            def forward(self, inp: torch.Tensor) -> torch.Tensor:
+                out = self.layer1(inp)
+                out = self.layer2(out)
+                return out.sum() + self.weight_norm.sum()
+
+        with torch.device("meta"):
+            model = MyModel()
+        fully_shard(model.layer1)
+        fully_shard(model.layer2)
+        fully_shard(model)
+
+        model.layer1.to_empty(device="cuda")
+        model.layer2.to_empty(device="cuda")
+        model.init_weight_norm()
+
+        inp = torch.randn(3, 3, device="cuda")
+        loss = model(inp).sum()
+        loss.backward()
+
 
 class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
     @property
diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py
index dbbcf6e..3c2f639 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py
@@ -131,6 +131,9 @@
         # Group's sharded state always matches its parameters' sharded states
         self._sharded_state = ShardedState.SHARDED
         self._module_fqn: Optional[str] = None  # prefixed from root module
+        # Only consider resetting sharded parameters once in lazy init since it
+        # can incur nontrivial overhead to reset them
+        self._reset_sharded_params: bool = False
 
         # - Hook state
         self._module_to_pre_save_state_dict_hook_handle: _ModuleToHandleDict = {}
@@ -191,6 +194,13 @@
 
     def lazy_init(self):
         # Lazy init should be idempotent
+        # Users may change or register parameters after construction time.
+        # For example, DoRA (https://arxiv.org/abs/2402.09353) initializes linear magnitudes based on
+        # other parameters (e.g. loaded from the state dict).
+        if self.is_sharded and not self._reset_sharded_params:
+            for fsdp_param in self.fsdp_params:
+                fsdp_param.reset_sharded_param()
+            self._reset_sharded_params = True
         param_names_on_meta = [
             fsdp_param._param_fqn
             for fsdp_param in self.fsdp_params