[dynamo] enable 2d torch.compile test (#107473)
This PR adds a 2D parallel torch.compile test on a simple MLP model and
verifies that the dynamo changes work. Once @bdhirsh's aot autograd enablement
is done, we can switch this test to exercise the e2e torch.compile workflow.
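For reference, the 2D (FSDP + TP) compile pattern the test exercises boils down to roughly the sketch below; the `compile_2d` helper and its `model`/`world_size`/`dp_size`/`rank` parameters are illustrative placeholders, and the `eager` backend stands in until the aot autograd path is ready:

```python
# Sketch only: mirrors the pattern in test_2d_fsdp_tp_compile below.
import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module
from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp


def compile_2d(model: torch.nn.Module, world_size: int, dp_size: int, rank: int):
    # Let FSDP handle the DTensor parameters produced by tensor parallelism.
    enable_2d_with_fsdp()

    # 2-D device mesh laid out as [dp, tp].
    mesh = DeviceMesh("cuda", torch.arange(world_size).view(dp_size, -1))

    # Shard the model along the tp dimension of the mesh.
    tp_model = parallelize_module(model, mesh, PairwiseParallel(), tp_mesh_dim=1)

    # Compile first (eager backend for now), then wrap with FSDP over the dp
    # process group; see the TODO in the test for why this ordering is used.
    compiled = torch.compile(tp_model, backend="eager", fullgraph=True)
    return FSDP(
        compiled,
        process_group=mesh.get_dim_groups()[0],
        device_id=rank,
        use_orig_params=True,
    )
```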
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107473
Approved by: https://github.com/fduwjj
ghstack dependencies: #107472
diff --git a/.ci/pytorch/multigpu-test.sh b/.ci/pytorch/multigpu-test.sh
index 23a27af..f3bf768 100755
--- a/.ci/pytorch/multigpu-test.sh
+++ b/.ci/pytorch/multigpu-test.sh
@@ -36,8 +36,9 @@
 # DTensor tests
-time python test/run_test.py --verbose -i distributed/_tensor/test_device_mesh.py
-time python test/run_test.py --verbose -i distributed/_tensor/test_random_ops.py
+time python test/run_test.py --verbose -i distributed/_tensor/test_device_mesh
+time python test/run_test.py --verbose -i distributed/_tensor/test_random_ops
+time python test/run_test.py --verbose -i distributed/_tensor/test_dtensor_compile
 # DTensor/TP tests
 time python test/run_test.py --verbose -i distributed/tensor/parallel/test_ddp_2d_parallel
diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py
new file mode 100644
index 0000000..7ff7795
--- /dev/null
+++ b/test/distributed/_tensor/test_dtensor_compile.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+import copy
+
+import torch
+import torch._dynamo
+import torch.nn as nn
+from torch.distributed._tensor import DeviceMesh
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module
+from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    MLPModule,
+    with_comms,
+)
+
+
+class SimpleModel(nn.Module):
+    def __init__(self, device):
+        super().__init__()
+        self.mlp_0 = MLPModule(device)
+        self.mlp_1 = MLPModule(device)
+
+    def forward(self, input):
+        return self.mlp_1(self.mlp_0(input))
+
+
+class TestDTensorCompile(DTensorTestBase):
+    @property
+    def world_size(self):
+        return 4
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    def test_2d_fsdp_tp_compile(self):
+        data_parallel_size = 2
+        model = SimpleModel(self.device_type)
+        model_copy = copy.deepcopy(model)
+        enable_2d_with_fsdp()
+
+        # 2-D mesh is [dp, tp]
+        twod_mesh = DeviceMesh(
+            device_type="cuda",
+            mesh=torch.arange(0, self.world_size).view(data_parallel_size, -1),
+        )
+
+        fsdp_pg = twod_mesh.get_dim_groups()[0]
+
+        inp = torch.rand(20, 10, device=self.device_type)
+        tp_model = parallelize_module(
+            model, twod_mesh, PairwiseParallel(), tp_mesh_dim=1
+        )
+        eager_2d = FSDP(
+            tp_model, process_group=fsdp_pg, device_id=self.rank, use_orig_params=True
+        )
+        out = eager_2d(inp)
+        # TODO: once aot autograd support is ready we can just use the default backend
+        tp_model2 = parallelize_module(
+            model_copy, twod_mesh, PairwiseParallel(), tp_mesh_dim=1
+        )
+        compiled_tp = torch.compile(tp_model2, backend="eager", fullgraph=True)
+
+        # TODO: for now we apply torch.compile to the TP model first and then wrap it with FSDP;
+        # ideally we should apply torch.compile after the FSDP wrap, but the current graph break
+        # approach has some issues with tensor subclass compilation that we need to dig into later
+        compiled_2d = FSDP(
+            compiled_tp,
+            process_group=fsdp_pg,
+            device_id=self.rank,
+            use_orig_params=True,
+        )
+
+        compiled_output = compiled_2d(inp)
+
+        self.assertEqual(out, compiled_output)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py
index df48ba8..6642e9c 100644
--- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py
+++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py
@@ -56,7 +56,7 @@
         torch.manual_seed(5)
         self.net1 = torch.nn.Linear(10, 16, device=device)
         self.relu = torch.nn.ReLU()
-        self.net2 = torch.nn.Linear(16, 12, device=device)
+        self.net2 = torch.nn.Linear(16, 10, device=device)

     def forward(self, x):
         return self.net2(self.relu(self.net1(x)))