# Owner(s): ["oncall: distributed"]
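"""Tests for the experimental torch.distributed._hooks API: process-group
creation hooks, collective start/end hooks, and event-queue overflow
behavior."""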

import os
import sys
import tempfile
import threading
from functools import partial, wraps

import torch
import torch.distributed as dist
import torch.distributed._hooks as dhooks

if not dist.is_available():
    print("torch.distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)


from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)

from torch.testing._internal.common_utils import run_tests, TestCase


class PgHooks(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        return 4

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def test_pg_hook(self):
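        """Verify the PG hook fires for the default PG and for each
        new_group() the current rank belongs to, and is silent otherwise."""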
        pgs = []

        def pg_hook(pg, pg_name):
            pgs.append((pg, pg_name))

        dhooks.register_process_group_hook(pg_hook)
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )
        self.assertEqual(len(pgs), 1)
        self.assertEqual(pgs[0][0], dist.group.WORLD)

        # create two partial world PGs
        pg0 = dist.new_group(ranks=[0, 1])
        pg1 = dist.new_group(ranks=[2, 3])

        # Each rank only observes two PG creations: the default PG and the one
        # covering its own ranks. No event is emitted for a PG the current rank
        # doesn't belong to. For example, on rank 1 you'll see an event for pg0
        # but not pg1, even though the API contract requires calling new_group
        # for both.
        self.assertEqual(len(pgs), 2)
        self.assertEqual(pgs[1][0], pg0 if self.rank < 2 else pg1)


def with_comms(func=None):
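    """Decorator that calls self.init_comms() before the test body and
    self.destroy_comms() after it."""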
    if func is None:
        return partial(with_comms)

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        self.init_comms()
        try:
            func(self, *args, **kwargs)
        finally:
            # Tear down even if the test body raises so the process group
            # doesn't leak into later tests.
            self.destroy_comms()

    return wrapper


class CollectiveHooks:
    @property
    def world_size(self) -> int:
        return 4

    def _collective_hooks(self):
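        """Register start/end hooks, run one all_gather and one all_reduce,
        then check both hooks fired once per collective with consistent
        status fields."""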
        # It's safe to read these lists without locking: a single background
        # thread invokes the hooks.
        starts = []
        ends = []
        cv = threading.Condition()

        def coll_start(status):
            starts.append(status)
            print(f"coll_start {len(starts)} rank{self.rank}")

        def coll_end(status):
            ends.append(status)
            print(f"coll_end {len(ends)} rank{self.rank}")
            if len(ends) == 2:
                with cv:
                    cv.notify()

        dhooks.register_collective_start_hook(coll_start)
        dhooks.register_collective_end_hook(coll_end)

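        # Issue two collectives; each should produce one start and one end
        # event on the hook thread.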
        tensor = torch.ones([2, 3]).to(self.device) * self.rank
        tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]

        dist.all_gather(tensor_list, tensor)

        tensor2 = torch.ones([2, 3]).to(self.device) * self.rank
        dist.all_reduce(tensor2)

        with cv:
            # wait_for re-checks the predicate under the lock, so a notify
            # that fired before we started waiting is not lost.
            cv.wait_for(lambda: len(ends) == 2, 1)

        default_pg_name = dist.group.WORLD.group_name
        self.assertEqual(2, len(starts))
        self.assertEqual(2, len(ends))

        def check_op(idx, coll_name):
            self.assertEqual(default_pg_name, starts[idx].pg_name)
            self.assertEqual(self.backend_name, starts[idx].backend)
            self.assertGreaterEqual(starts[idx].sequence_number, 0)
            self.assertGreaterEqual(starts[idx].timestamp, 0)
            self.assertEqual(coll_name, starts[idx].operation)

            self.assertEqual(default_pg_name, ends[idx].pg_name)
            self.assertEqual(self.backend_name, ends[idx].backend)

            self.assertEqual(starts[idx].sequence_number, ends[idx].sequence_number)
            self.assertLessEqual(starts[idx].timestamp, ends[idx].timestamp)
            self.assertEqual(coll_name, ends[idx].operation)

        check_op(0, "ALLGATHER")
        check_op(1, "ALLREDUCE")


class GlooHooks(MultiProcessTestCase, CollectiveHooks):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def init_comms(self):
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_comms(self):
        dist.destroy_process_group()

    @property
    def backend_name(self):
        return "gloo"

    @property
    def device(self):
        return "cpu"

    @with_comms
    def test_collective_hooks(self):
        self._collective_hooks()


class NcclHooks(MultiProcessTestCase, CollectiveHooks):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def init_comms(self):
        dist.init_process_group(
            backend="nccl",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_comms(self):
        dist.destroy_process_group()

    @property
    def backend_name(self):
        return "nccl"

    @property
    def device(self):
        return f"cuda:{self.rank}"

    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_collective_hooks(self):
        self._collective_hooks()


class SingleRankTests(TestCase):
    def setUp(self) -> None:
        super().setUp()
        self.rank = 0
        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
        dist.init_process_group(
            backend="gloo",
            rank=0,
            world_size=1,
            store=dist.FileStore(self.file_name, 1),
        )

    def tearDown(self) -> None:
        super().tearDown()
        dist.destroy_process_group()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def test_queue_overflow(self) -> None:
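        """The start hook blocks the hook thread until all collectives have
        been issued, so the native event queue overflows and a later status
        must report a positive drop_count."""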
        cv_done_colls = threading.Condition()
        cv_done_cb = threading.Condition()
        colls_done = False
        starts = []
        status_with_dropped = None

        def coll_start(status: dhooks.CollectiveStatus):
            starts.append(status)
            with cv_done_colls:
                while not colls_done:
                    cv_done_colls.wait()
            if status.drop_count > 0:
                nonlocal status_with_dropped
                status_with_dropped = status
                with cv_done_cb:
                    cv_done_cb.notify()

        dhooks.register_collective_start_hook(coll_start)

        # The native event queue holds at most 512 entries; issuing 600
        # collectives while the start hook blocks guarantees some drops.
        for _ in range(600):
            dist.all_reduce(torch.ones([2, 3]))
        colls_done = True
        with cv_done_colls:
            cv_done_colls.notify()

        with cv_done_cb:
            # wait_for avoids a lost wakeup if the callback fired before we
            # started waiting.
            cv_done_cb.wait_for(lambda: status_with_dropped is not None, 10)

        self.assertIsNotNone(status_with_dropped)
        self.assertGreater(status_with_dropped.drop_count, 0)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "this test file must not have initialized CUDA context on the main process"

    run_tests()