# Owner(s): ["oncall: distributed"]
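"""Tests for the "fake" process group backend.

The fake backend (initialized here with a FakeStore or HashStore) creates a
process group whose collectives complete without performing any real
communication, so a single process can exercise distributed code paths such
as collectives, FSDP, tensor parallelism, and make_fx tracing.
"""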

import sys
import unittest

import torch
import torch.distributed as dist
import torch.nn as nn

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    PairwiseParallel,
    SequenceParallel,
    parallelize_module,
)
from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore

HAS_CUDA = torch.cuda.is_available()


class TestFakePG(TestCase):
    def tearDown(self):
        super().tearDown()
        dist.destroy_process_group()
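
    # The fake backend completes collectives without moving any data between
    # ranks, so the collective tests below only check that the calls succeed
    # and that output shapes are preserved.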
    def test_all_reduce(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=1, world_size=2, store=store)
        output = torch.ones(3, 3) * dist.get_rank()
        dist.all_reduce(output)
        self.assertEqual(tuple(output.shape), (3, 3))

    def test_allgather(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=1, world_size=2, store=store)
        input_tensor = torch.ones(3, 3) * dist.get_rank()
        output_tensors = [torch.empty_like(input_tensor) for _ in range(2)]
        dist.all_gather(output_tensors, input_tensor)
        for out_tensor in output_tensors:
            self.assertEqual(tuple(out_tensor.shape), (3, 3))

    def test_reduce_scatter(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=1, world_size=2, store=store)
        to_reduce_scatter = [torch.ones(3, 3) * rank for rank in range(2)]
        output_tensor = torch.empty(3, 3)
        dist.reduce_scatter(output_tensor, to_reduce_scatter)
        self.assertEqual(tuple(output_tensor.shape), (3, 3))
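
    # FSDP also works under the fake backend: a single process can wrap a
    # module and run forward, backward, and an optimizer step end to end.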
    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_construct_fsdp(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
        FSDP(nn.Linear(2, 3, device="cuda"))

    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_fsdp_fake_e2e(self):
        store = dist.HashStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
        my_module = nn.Sequential(
            nn.Linear(2, 3, device="cuda"),
            nn.ReLU(),
            nn.Linear(3, 2, device="cuda"),
        )
        sharded_module = FSDP(my_module, use_orig_params=True)
        optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
        input = torch.randn(2, 2)
        x = sharded_module(input)
        loss = x.sum()
        loss.backward()
        optim.step()
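
    # Functional collectives can be traced with make_fx under the fake backend;
    # the captured graph should contain the all_gather op followed by the
    # matching wait_tensor.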
    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_fake_pg_tracing(self):
        store = dist.HashStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        default_pg = dist.distributed_c10d._get_default_group()

        def allgather_fn(tensor):
            return funcol.all_gather_tensor(tensor, 0, default_pg)

        gm = make_fx(allgather_fn)(torch.randn(2, 2, device="cuda"))
        FileCheck().check("all_gather").check("wait_tensor").run(str(gm.graph))

    def test_broadcast(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        # src == rank
        output = torch.ones(3, 3)
        dist.broadcast(output, src=0)
        self.assertEqual(tuple(output.shape), (3, 3))

        # src != rank
        output = torch.ones(3, 3)
        dist.broadcast(output, src=1)
        self.assertEqual(tuple(output.shape), (3, 3))

    def test_scatter(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        # src == rank
        output = torch.ones(3, 3)
        to_scatter = [torch.ones(3, 3) * rank for rank in range(2)]
        dist.scatter(output, to_scatter)
        self.assertEqual(tuple(output.shape), (3, 3))

        # src != rank
        output = torch.ones(3, 3)
        dist.scatter(output, None, src=1)
        self.assertEqual(tuple(output.shape), (3, 3))

    def test_alltoall(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
        output_list = [torch.ones(3, 3) for _ in range(2)]
        input_list = [torch.ones(3, 3) for _ in range(2)]
        dist.all_to_all(output_list, input_list)
        self.assertEqual(len(output_list), 2)
        for output in output_list:
            self.assertEqual(tuple(output.shape), (3, 3))

    def test_alltoall_base(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
        out_tensor = torch.ones(3, 3)
        in_tensor = torch.ones(3, 3)
        output_split = [1, 1]
        input_split = [1, 1]
        dist.all_to_all_single(out_tensor, in_tensor, output_split, input_split)
        self.assertEqual(tuple(out_tensor.shape), (3, 3))

    def test_send(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
        tensor = torch.ones(3, 3)
        dist.send(tensor, 1)
        self.assertEqual(tuple(tensor.shape), (3, 3))

    def test_recv(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
        output = torch.ones(3, 3)
        dist.recv(output, 1)
        self.assertEqual(tuple(output.shape), (3, 3))
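
    # 2D parallelism smoke test: the 4-rank mesh is viewed as (dp=2, tp=2),
    # tensor parallelism is applied along mesh dim 1, and FSDP wraps the
    # parallelized module for the data-parallel dimension.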
    @unittest.skipIf(not HAS_CUDA or not enable_2d_with_fsdp(), "No CUDA or TP+FSDP")
    def test_fsdp_tp_fake_e2e(self):
        world_size = 4
        tp_size = 2

        store = dist.HashStore()
        dist.init_process_group(
            backend="fake", rank=0, world_size=world_size, store=store
        )

        device_mesh = DeviceMesh("cuda", torch.arange(0, world_size).view(-1, tp_size))

        for parallel_style in [SequenceParallel(), PairwiseParallel()]:
            my_module = parallelize_module(
                MLPModule(device="cuda"), device_mesh, parallel_style, tp_mesh_dim=1
            )
            sharded_module = FSDP(my_module, use_orig_params=True)
            optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)

            for i in range(10):
                dp_rank = dist.get_rank()
                torch.manual_seed(i + dp_rank)
                input = torch.randn(20, 10).cuda(dp_rank)
                x = sharded_module(input)
                loss = x.sum()
                loss.backward()
                optim.step()
if __name__ == "__main__":
run_tests()