[PT-D][Tensor Parallel] Add more test cases when using use_orig_params for FSDP wrapping (#89779)
Differential Revision: [D41600656](https://our.internmc.facebook.com/intern/diff/D41600656)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89779
Approved by: https://github.com/wanchaol
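
For context, the new `_distribute_and_fsdp_wrap_module` helper in this diff composes tensor parallelism and FSDP on a 2D device mesh, optionally wrapping submodules in nested FSDP instances and forwarding `use_orig_params`. The sketch below illustrates that composition pattern outside the test harness; it is an assumption-laden illustration, not the test code itself: the `net1`/`net2` submodule names come from the test's `SimpleModel`, the mesh construction mirrors what `init_model` appears to do, and actually running it requires an initialized distributed environment with `world_size = dp_degree * tp_degree` GPUs.

```python
# Minimal sketch of the 2D (TP + FSDP) wrapping pattern exercised by these tests.
# Assumes torch.distributed is already initialized and that `model` exposes
# `net1`/`net2` submodules like the test's SimpleModel (an assumption here).
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.parallel import PairwiseParallel, parallelize_module
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def wrap_2d(model, tp_degree, use_orig_params=False, fsdp_nested=False):
    world_size = dist.get_world_size()
    # Arrange ranks into a (data-parallel, tensor-parallel) 2D mesh; the exact
    # mesh construction is assumed, based on the test's init_model.
    mesh_2d = DeviceMesh(
        "cuda",
        torch.arange(world_size).reshape(world_size // tp_degree, tp_degree),
    )
    # Shard the model across the TP dimension (dim 1) of the mesh.
    model = parallelize_module(model, mesh_2d, PairwiseParallel(), tp_mesh_dim=1)
    # FSDP shards along the remaining (data-parallel) dimension, dim 0.
    dp_pg = mesh_2d.get_dim_groups()[0]
    if fsdp_nested:
        # Wrap submodules first so each gets its own FSDP unit, as the new
        # fsdp_nested test path does.
        model.net1 = FSDP(
            model.net1, process_group=dp_pg, use_orig_params=use_orig_params
        )
        model.net2 = FSDP(
            model.net2, process_group=dp_pg, use_orig_params=use_orig_params
        )
    return FSDP(model, process_group=dp_pg, use_orig_params=use_orig_params)
```

With `use_orig_params=True`, the wrapped model keeps the original parameter names, which is what enables the per-submodule optimizer param groups added in `test_2d_fsdp_integration_fsdp_nested_param_groups` below.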
diff --git a/test/distributed/_tensor/parallel/test_2d_parallel.py b/test/distributed/_tensor/parallel/test_2d_parallel.py
index ea41d53..da6d1f5 100644
--- a/test/distributed/_tensor/parallel/test_2d_parallel.py
+++ b/test/distributed/_tensor/parallel/test_2d_parallel.py
@@ -10,10 +10,8 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed._tensor import (
- distribute_tensor,
DeviceMesh,
DTensor as DT,
- Shard,
Replicate,
)
from torch.distributed._tensor.parallel import (
@@ -52,67 +50,26 @@
return x
-def _aggregate_local_tensor(module: torch.nn.Module) -> torch.nn.Module:
- def hook_func(_module, _input, output):
- if isinstance(output, DT):
- replica_placement = [Replicate()]
- return output.redistribute(
- output.device_mesh, replica_placement
- ).to_local()
-
- module.register_forward_hook(hook_func)
- return module
-
-
-def _replicate_input_tensor(
- module: torch.nn.Module, device_mesh, replica_placement
-) -> torch.nn.Module:
- def hook_func(_, input):
- if not isinstance(input[0], DT):
- return DT.from_local(
- input[0], device_mesh, replica_placement, run_check=False
- )
-
- module.register_forward_pre_hook(hook_func)
- return module
-
-
-def shard_module(m, pg):
- start_idx = distributed_c10d.get_global_rank(pg, 0)
- device_mesh = DeviceMesh(
- "cuda", list(range(start_idx, start_idx + pg.size())), dim_groups=[pg]
- )
- col_wise_sharding = [Shard(0)]
- row_wise_sharding = [Shard(1)]
- replicate = [Replicate()]
- m.net1.weight = torch.nn.Parameter(
- distribute_tensor(m.net1.weight, device_mesh, col_wise_sharding),
- )
- m.net2.weight = torch.nn.Parameter(
- distribute_tensor(m.net2.weight, device_mesh, row_wise_sharding)
- )
- m.net1.bias = torch.nn.Parameter(
- distribute_tensor(m.net1.bias, device_mesh, col_wise_sharding)
- )
- m.net2.bias = torch.nn.Parameter(
- distribute_tensor(m.net2.bias, device_mesh, replicate)
- )
- m = _replicate_input_tensor(m, device_mesh, replicate)
- m.net2 = _aggregate_local_tensor(m.net2)
-
-
-def _shard_wrap_module(module, module_shard, fsdp_wrap, mesh_2d, fsdp_pg):
+def _distribute_and_fsdp_wrap_module(
+ module, module_shard, mesh_2d, fsdp_pg, use_orig_params, fsdp_nested
+):
if module_shard:
- parallelize_module(module, mesh_2d, PairwiseParallel(), tp_mesh_dim=1)
+ module = parallelize_module(module, mesh_2d, PairwiseParallel(), tp_mesh_dim=1)
+ pg = fsdp_pg if module_shard else distributed_c10d._get_default_group()
- if fsdp_wrap and module_shard:
- return FSDP(module, process_group=fsdp_pg)
- if fsdp_wrap:
- return FSDP(module, process_group=distributed_c10d._get_default_group())
- return module
+ if fsdp_nested:
+ module.net1 = FSDP(
+ module.net1, process_group=pg, use_orig_params=use_orig_params
+ )
+ module.net2 = FSDP(
+ module.net2, process_group=pg, use_orig_params=use_orig_params
+ )
+ return FSDP(
+ module, process_group=pg, use_orig_params=use_orig_params
+ )
-def init_model(model_parallel_size=TP_DEGREE):
+def init_model(model_parallel_size=TP_DEGREE, use_orig_params=False, fsdp_nested=False):
rank = dist.get_rank()
torch.cuda.set_device(rank)
world_size = dist.get_world_size()
@@ -128,7 +85,9 @@
fsdp_pg = twod_mesh.get_dim_groups()[0]
# Create Input
- model = _shard_wrap_module(model, True, True, twod_mesh, fsdp_pg)
+ model = _distribute_and_fsdp_wrap_module(
+ model, True, twod_mesh, fsdp_pg, use_orig_params, fsdp_nested
+ )
return model, fsdp_pg
@@ -182,19 +141,50 @@
is_nested_tensor(optim_state["state"]["net3.bias"]["exp_avg"])
)
- @with_comms
- @skip_if_lt_x_gpu(4)
- def test_2d_fsdp_integration_correctness(self) -> None:
+ def _compare_params(self, m1, m2):
+ with FSDP.summon_full_params(m1):
+ with FSDP.summon_full_params(m2):
+ for n_p1, n_p2 in zip(m1.named_parameters(), m2.named_parameters()):
+ p1 = n_p1[1]
+ p2 = n_p2[1]
+ self.assertEqual(n_p1[0], n_p2[0])
+ name = n_p1[0]
+ if name == "net2.bias" and self.rank != 0:
+ continue
+ if type(p2) is DT:
+ p2 = p2.redistribute(
+ p2.device_mesh, [Replicate()]
+ ).to_local()
+ self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")
+
+ def _test_2d_e2e_flow(self, use_orig_params=False, fsdp_nested=False, multi_param_group=False) -> None:
if not is_available():
self.skipTest("FSDP 2d parallel integration not available")
torch.manual_seed(0)
model = SimpleModel().cuda(self.rank)
- model = FSDP(model)
+ model = FSDP(model, use_orig_params=use_orig_params)
torch.manual_seed(0)
- model_2d, dp_pg = init_model()
+ model_2d, dp_pg = init_model(use_orig_params=use_orig_params, fsdp_nested=fsdp_nested)
+ # Check named parameters are returning the same name at least.
+ param_names_2d = [name for name, _ in model_2d.named_parameters()]
+ for name, _ in model.named_parameters():
+ self.assertTrue(name in param_names_2d)
+ self._compare_params(model, model_2d)
- optim = torch.optim.Adam(model.parameters(), lr=0.0001)
- optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.0001)
+ if multi_param_group and use_orig_params:
+ param_group = [
+ {"params": model.net1.parameters(), "lr": 0.02},
+ {"params": model.net2.parameters(), "lr": 0.15},
+ ]
+ optim = torch.optim.Adam(param_group, lr=0.01)
+ param_group = [
+ {"params": model_2d.net1.parameters(), "lr": 0.02},
+ {"params": model_2d.net2.parameters(), "lr": 0.15},
+ ]
+ optim_2d = torch.optim.Adam(param_group, lr=0.01)
+ else:
+ optim = torch.optim.Adam(model.parameters(), lr=0.01)
+ optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)
for i in range(5):
# Ensure all input across TP ranks are same.
@@ -209,6 +199,29 @@
optim_2d.step()
self.assertEqual(model(input), model_2d(input))
+ # Ensure all params are still the same after optimizer update.
+ self._compare_params(model, model_2d)
+
+ @with_comms
+ @skip_if_lt_x_gpu(4)
+ def test_2d_fsdp_integration_correctness(self) -> None:
+ self._test_2d_e2e_flow()
+
+ @with_comms
+ @skip_if_lt_x_gpu(4)
+ def test_2d_fsdp_integration_use_orig_params(self) -> None:
+ self._test_2d_e2e_flow(use_orig_params=True)
+
+ @with_comms
+ @skip_if_lt_x_gpu(4)
+ def test_2d_fsdp_integration_fsdp_nested(self) -> None:
+ self._test_2d_e2e_flow(fsdp_nested=True)
+
+ @with_comms
+ @skip_if_lt_x_gpu(4)
+ def test_2d_fsdp_integration_fsdp_nested_param_groups(self) -> None:
+ self._test_2d_e2e_flow(fsdp_nested=True, use_orig_params=True, multi_param_group=True)
+
if __name__ == "__main__":
run_tests()
diff --git a/torch/distributed/_tensor/dispatch.py b/torch/distributed/_tensor/dispatch.py
index 8c9e5a2..38ea056 100644
--- a/torch/distributed/_tensor/dispatch.py
+++ b/torch/distributed/_tensor/dispatch.py
@@ -151,7 +151,9 @@
_CURRENT_DECOMPOSITION_TABLE: Dict[
Callable[..., object], Callable[..., object]
-] = {torch.ops.aten._reshape_alias.default: _reshape_alias}
+] = {
+ torch.ops.aten._reshape_alias.default: _reshape_alias,
+}
def propagate_input_sharding(