[PT-D][Tensor Parallel] Add more test cases for FSDP wrapping with use_orig_params (#89779)

Differential Revision: [D41600656](https://our.internmc.facebook.com/intern/diff/D41600656)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89779
Approved by: https://github.com/wanchaol
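
This change refactors the 2D (FSDP + Tensor Parallel) integration test: `_shard_wrap_module` becomes `_distribute_and_fsdp_wrap_module`, which always FSDP-wraps the TP-parallelized module and can optionally wrap `net1`/`net2` as nested FSDP units, and the end-to-end flow gains `use_orig_params`, nested-FSDP, and multi-parameter-group variants plus a parameter comparison against a plain FSDP model.

Below is a minimal sketch (not part of the diff) of the wrapping pattern the new tests exercise. It assumes an already-initialized CUDA process group with 4 GPUs; the `wrap_2d` helper name, the exact mesh layout, and `TP_DEGREE` here are illustrative, adapted from the test file:

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.parallel import PairwiseParallel, parallelize_module
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

TP_DEGREE = 2  # tensor-parallel degree, matching the test's default


def wrap_2d(module, use_orig_params=False, fsdp_nested=False):
    # Build a 2-D mesh: dim 0 is the data-parallel (FSDP) dimension,
    # dim 1 is the tensor-parallel dimension (hence tp_mesh_dim=1 below).
    world_size = dist.get_world_size()
    mesh_2d = DeviceMesh("cuda", torch.arange(world_size).reshape(-1, TP_DEGREE))
    fsdp_pg = mesh_2d.get_dim_groups()[0]

    # Shard the module pairwise (column-wise then row-wise) across the TP dimension.
    module = parallelize_module(module, mesh_2d, PairwiseParallel(), tp_mesh_dim=1)

    if fsdp_nested:
        # Wrap the submodules first so each becomes its own FSDP unit;
        # the multi-param-group test relies on this structure.
        module.net1 = FSDP(
            module.net1, process_group=fsdp_pg, use_orig_params=use_orig_params
        )
        module.net2 = FSDP(
            module.net2, process_group=fsdp_pg, use_orig_params=use_orig_params
        )

    return FSDP(module, process_group=fsdp_pg, use_orig_params=use_orig_params)
```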
diff --git a/test/distributed/_tensor/parallel/test_2d_parallel.py b/test/distributed/_tensor/parallel/test_2d_parallel.py
index ea41d53..da6d1f5 100644
--- a/test/distributed/_tensor/parallel/test_2d_parallel.py
+++ b/test/distributed/_tensor/parallel/test_2d_parallel.py
@@ -10,10 +10,8 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 from torch.distributed._tensor import (
-    distribute_tensor,
     DeviceMesh,
     DTensor as DT,
-    Shard,
     Replicate,
 )
 from torch.distributed._tensor.parallel import (
@@ -52,67 +50,26 @@
         return x
 
 
-def _aggregate_local_tensor(module: torch.nn.Module) -> torch.nn.Module:
-    def hook_func(_module, _input, output):
-        if isinstance(output, DT):
-            replica_placement = [Replicate()]
-            return output.redistribute(
-                output.device_mesh, replica_placement
-            ).to_local()
-
-    module.register_forward_hook(hook_func)
-    return module
-
-
-def _replicate_input_tensor(
-    module: torch.nn.Module, device_mesh, replica_placement
-) -> torch.nn.Module:
-    def hook_func(_, input):
-        if not isinstance(input[0], DT):
-            return DT.from_local(
-                input[0], device_mesh, replica_placement, run_check=False
-            )
-
-    module.register_forward_pre_hook(hook_func)
-    return module
-
-
-def shard_module(m, pg):
-    start_idx = distributed_c10d.get_global_rank(pg, 0)
-    device_mesh = DeviceMesh(
-        "cuda", list(range(start_idx, start_idx + pg.size())), dim_groups=[pg]
-    )
-    col_wise_sharding = [Shard(0)]
-    row_wise_sharding = [Shard(1)]
-    replicate = [Replicate()]
-    m.net1.weight = torch.nn.Parameter(
-        distribute_tensor(m.net1.weight, device_mesh, col_wise_sharding),
-    )
-    m.net2.weight = torch.nn.Parameter(
-        distribute_tensor(m.net2.weight, device_mesh, row_wise_sharding)
-    )
-    m.net1.bias = torch.nn.Parameter(
-        distribute_tensor(m.net1.bias, device_mesh, col_wise_sharding)
-    )
-    m.net2.bias = torch.nn.Parameter(
-        distribute_tensor(m.net2.bias, device_mesh, replicate)
-    )
-    m = _replicate_input_tensor(m, device_mesh, replicate)
-    m.net2 = _aggregate_local_tensor(m.net2)
-
-
-def _shard_wrap_module(module, module_shard, fsdp_wrap, mesh_2d, fsdp_pg):
+def _distribute_and_fsdp_wrap_module(
+    module, module_shard, mesh_2d, fsdp_pg, use_orig_params, fsdp_nested
+):
     if module_shard:
-        parallelize_module(module, mesh_2d, PairwiseParallel(), tp_mesh_dim=1)
+        module = parallelize_module(module, mesh_2d, PairwiseParallel(), tp_mesh_dim=1)
+    pg = fsdp_pg if module_shard else distributed_c10d._get_default_group()
 
-    if fsdp_wrap and module_shard:
-        return FSDP(module, process_group=fsdp_pg)
-    if fsdp_wrap:
-        return FSDP(module, process_group=distributed_c10d._get_default_group())
-    return module
+    if fsdp_nested:
+        module.net1 = FSDP(
+            module.net1, process_group=pg, use_orig_params=use_orig_params
+        )
+        module.net2 = FSDP(
+            module.net2, process_group=pg, use_orig_params=use_orig_params
+        )
+    return FSDP(
+        module, process_group=pg, use_orig_params=use_orig_params
+    )
 
 
-def init_model(model_parallel_size=TP_DEGREE):
+def init_model(model_parallel_size=TP_DEGREE, use_orig_params=False, fsdp_nested=False):
     rank = dist.get_rank()
     torch.cuda.set_device(rank)
     world_size = dist.get_world_size()
@@ -128,7 +85,9 @@
     fsdp_pg = twod_mesh.get_dim_groups()[0]
 
     # Create Input
-    model = _shard_wrap_module(model, True, True, twod_mesh, fsdp_pg)
+    model = _distribute_and_fsdp_wrap_module(
+        model, True, twod_mesh, fsdp_pg, use_orig_params, fsdp_nested
+    )
     return model, fsdp_pg
 
 
@@ -182,19 +141,50 @@
             is_nested_tensor(optim_state["state"]["net3.bias"]["exp_avg"])
         )
 
-    @with_comms
-    @skip_if_lt_x_gpu(4)
-    def test_2d_fsdp_integration_correctness(self) -> None:
+    def _compare_params(self, m1, m2):
+        with FSDP.summon_full_params(m1):
+            with FSDP.summon_full_params(m2):
+                for n_p1, n_p2 in zip(m1.named_parameters(), m2.named_parameters()):
+                    p1 = n_p1[1]
+                    p2 = n_p2[1]
+                    self.assertEqual(n_p1[0], n_p2[0])
+                    name = n_p1[0]
+                    if name == "net2.bias" and self.rank != 0:
+                        continue
+                    if type(p2) is DT:
+                        p2 = p2.redistribute(
+                            p2.device_mesh, [Replicate()]
+                        ).to_local()
+                    self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")
+
+    def _test_2d_e2e_flow(self, use_orig_params=False, fsdp_nested=False, multi_param_group=False) -> None:
         if not is_available():
             self.skipTest("FSDP 2d parallel integration not available")
         torch.manual_seed(0)
         model = SimpleModel().cuda(self.rank)
-        model = FSDP(model)
+        model = FSDP(model, use_orig_params=use_orig_params)
         torch.manual_seed(0)
-        model_2d, dp_pg = init_model()
+        model_2d, dp_pg = init_model(use_orig_params=use_orig_params, fsdp_nested=fsdp_nested)
+        # Check that both models return the same parameter names.
+        param_names_2d = [name for name, _ in model_2d.named_parameters()]
+        for name, _ in model.named_parameters():
+            self.assertTrue(name in param_names_2d)
+        self._compare_params(model, model_2d)
 
-        optim = torch.optim.Adam(model.parameters(), lr=0.0001)
-        optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.0001)
+        if multi_param_group and use_orig_params:
+            param_group = [
+                {"params": model.net1.parameters(), "lr": 0.02},
+                {"params": model.net2.parameters(), "lr": 0.15},
+            ]
+            optim = torch.optim.Adam(param_group, lr=0.01)
+            param_group = [
+                {"params": model_2d.net1.parameters(), "lr": 0.02},
+                {"params": model_2d.net2.parameters(), "lr": 0.15},
+            ]
+            optim_2d = torch.optim.Adam(param_group, lr=0.01)
+        else:
+            optim = torch.optim.Adam(model.parameters(), lr=0.01)
+            optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)
 
         for i in range(5):
             # Ensure all input across TP ranks are same.
@@ -209,6 +199,29 @@
             optim_2d.step()
             self.assertEqual(model(input), model_2d(input))
 
+        # Ensure all params are still the same after optimizer update.
+        self._compare_params(model, model_2d)
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    def test_2d_fsdp_integration_correctness(self) -> None:
+        self._test_2d_e2e_flow()
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    def test_2d_fsdp_integration_use_orig_params(self) -> None:
+        self._test_2d_e2e_flow(use_orig_params=True)
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    def test_2d_fsdp_integration_fsdp_nested(self) -> None:
+        self._test_2d_e2e_flow(fsdp_nested=True)
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    def test_2d_fsdp_integration_fsdp_nested_param_groups(self) -> None:
+        self._test_2d_e2e_flow(fsdp_nested=True, use_orig_params=True, multi_param_group=True)
+
 
 if __name__ == "__main__":
     run_tests()
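For reference, a usage sketch of the new multi-param-group path (building on the hypothetical `wrap_2d` helper above; assumes a 4-GPU launch, and the input shape is illustrative):

```python
rank = dist.get_rank()
model_2d = wrap_2d(SimpleModel().cuda(rank), use_orig_params=True, fsdp_nested=True)

# The test exercises per-submodule learning rates only with use_orig_params=True,
# which keeps the original (unflattened) parameters visible to the optimizer.
optim_2d = torch.optim.Adam(
    [
        {"params": model_2d.net1.parameters(), "lr": 0.02},
        {"params": model_2d.net2.parameters(), "lr": 0.15},
    ],
    lr=0.01,
)

inp = torch.rand(4, 5, device="cuda")  # shape is illustrative only
model_2d(inp).sum().backward()
optim_2d.step()
```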
diff --git a/torch/distributed/_tensor/dispatch.py b/torch/distributed/_tensor/dispatch.py
index 8c9e5a2..38ea056 100644
--- a/torch/distributed/_tensor/dispatch.py
+++ b/torch/distributed/_tensor/dispatch.py
@@ -151,7 +151,9 @@
 
 _CURRENT_DECOMPOSITION_TABLE: Dict[
     Callable[..., object], Callable[..., object]
-] = {torch.ops.aten._reshape_alias.default: _reshape_alias}
+] = {
+    torch.ops.aten._reshape_alias.default: _reshape_alias,
+}
 
 
 def propagate_input_sharding(