| # Owner(s): ["oncall: distributed"] | 
 |  | 
 | import os | 
 | import sys | 
 | from functools import wraps, partial | 
 |  | 
 | import torch | 
 | import torch.distributed as dist | 
 |  | 
 | if not dist.is_available(): | 
 |     print("Distributed not available, skipping tests", file=sys.stderr) | 
 |     sys.exit(0) | 
 |  | 
 | from torch.testing._internal.common_distributed import ( | 
 |     MultiProcessTestCase, | 
 |     TEST_SKIPS | 
 | ) | 
 |  | 
 | from torch.testing._internal.common_utils import ( | 
 |     run_tests, | 
 |     TEST_WITH_DEV_DBG_ASAN, | 
 | ) | 
 |  | 
 | if TEST_WITH_DEV_DBG_ASAN: | 
 |     print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr) | 
 |     sys.exit(0) | 
 |  | 
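# Pick NCCL when CUDA is available, otherwise fall back to Gloo on CPU.
# Run with between 2 and 4 ranks, scaled to the number of visible GPUs.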
 | BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO | 
 | WORLD_SIZE = min(4, max(2, torch.cuda.device_count())) | 
 |  | 
 | def with_comms(func=None): | 
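    """Initialize the default process group before the test and destroy it afterwards.

    Skips the test when NCCL is selected but there are fewer GPUs than ranks.
    """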
 |     if func is None: | 
 |         return partial( | 
 |             with_comms, | 
 |         ) | 
 |  | 
 |     @wraps(func) | 
 |     def wrapper(self, *args, **kwargs): | 
 |         if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size: | 
 |             sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) | 
 |         self.dist_init() | 
        func(self, *args, **kwargs)
 |         self.destroy_comms() | 
 |     return wrapper | 
 |  | 
 | class TestObjectCollectives(MultiProcessTestCase): | 
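    """Tests for the object-based collectives (all_gather_object, gather_object,
    broadcast_object_list, scatter_object_list) on the default group and on subgroups.
    """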
 |     def setUp(self): | 
 |         super().setUp() | 
 |         os.environ["WORLD_SIZE"] = str(self.world_size) | 
 |         os.environ["BACKEND"] = BACKEND | 
 |         self._spawn_processes() | 
 |  | 
 |     @property | 
 |     def device(self): | 
        return (
            torch.device(self.rank)
            if BACKEND == dist.Backend.NCCL
            else torch.device("cpu")
        )
 |  | 
 |     @property | 
 |     def world_size(self): | 
 |         return WORLD_SIZE | 
 |  | 
 |     @property | 
 |     def process_group(self): | 
 |         return dist.group.WORLD | 
 |  | 
 |     def destroy_comms(self): | 
 |         # Wait for all ranks to reach here before starting shutdown. | 
 |         dist.barrier() | 
 |         dist.destroy_process_group() | 
 |  | 
 |     def dist_init(self): | 
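        # Rendezvous over the temporary file provided by MultiProcessTestCase.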
 |         dist.init_process_group( | 
 |             backend=BACKEND, | 
 |             world_size=self.world_size, | 
 |             rank=self.rank, | 
 |             init_method=f"file://{self.file_name}", | 
 |         ) | 
 |  | 
        # Set this rank's GPU as the current device so NCCL collectives use it.
        if BACKEND == dist.Backend.NCCL:
 |             torch.cuda.set_device(self.rank) | 
 |  | 
 |     @with_comms() | 
 |     def test_all_gather_object(self): | 
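        # Every rank contributes its own rank; all ranks should end up with [0, 1, ..., world_size - 1].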
 |         output = [None] * dist.get_world_size() | 
 |         dist.all_gather_object( | 
 |             object_list=output, | 
 |             obj=self.rank) | 
 |  | 
 |         for i, v in enumerate(output): | 
 |             self.assertEqual(i, v, f"rank: {self.rank}") | 
 |  | 
 |     @with_comms() | 
 |     def test_gather_object(self): | 
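        # Only the destination rank (0, the default dst) supplies the gather list and checks the result.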
 |         output = [None] * dist.get_world_size() if self.rank == 0 else None | 
 |         dist.gather_object( | 
 |             obj=self.rank, | 
 |             object_gather_list=output) | 
 |  | 
 |         if self.rank == 0: | 
 |             for i, v in enumerate(output): | 
 |                 self.assertEqual(i, v, f"rank: {self.rank}") | 
 |  | 
 |     @with_comms() | 
 |     def test_broadcast_object_list(self): | 
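        # Rank 0 (the default src) broadcasts 99; every rank should observe it after the call.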
 |         val = 99 if self.rank == 0 else None | 
 |         object_list = [val] * dist.get_world_size() | 
 |         # TODO test with broadcast_object_list's device argument | 
 |         dist.broadcast_object_list(object_list=object_list) | 
 |  | 
 |         self.assertEqual(99, object_list[0]) | 
 |  | 
 |     @with_comms() | 
 |     def test_scatter_object_list(self): | 
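        # Rank 0 scatters [0, 1, ..., world_size - 1]; each rank should receive its own rank.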
 |         input_list = list(range(dist.get_world_size())) if self.rank == 0 else None | 
 |         output_list = [None] | 
 |         dist.scatter_object_list( | 
 |             scatter_object_output_list=output_list, | 
 |             scatter_object_input_list=input_list) | 
 |  | 
 |         self.assertEqual(self.rank, output_list[0]) | 
 |  | 
    # The tests below exercise object collectives on a two-rank subgroup built from consecutive rank pairs.
 |  | 
 |     def setup_sub_pg(self): | 
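        """Split ranks into consecutive pairs (0-1, 2-3, ...) and return this rank, the pair's ranks, and the subgroup."""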
 |         rank = dist.get_rank() | 
 |         base_rank = rank - (rank % 2) | 
 |         ranks = [base_rank, base_rank + 1] | 
 |         my_pg = dist.new_group(ranks, use_local_synchronization=True) | 
 |         return rank, ranks, my_pg | 
 |  | 
 |     @with_comms() | 
 |     def test_subpg_scatter_object(self): | 
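        # The first rank of each pair scatters the pair's ranks; each member should get its own rank back.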
 |         rank, ranks, my_pg = self.setup_sub_pg() | 
 |         out_list = [None] | 
 |         dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg) | 
 |         self.assertEqual(rank, out_list[0]) | 
 |  | 
 |     @with_comms() | 
 |     def test_subpg_all_gather_object(self): | 
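        # Each pair member contributes its rank; both should see the pair's ranks in order.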
 |         rank, ranks, my_pg = self.setup_sub_pg() | 
 |         out_list = [None] * len(ranks) | 
 |         dist.all_gather_object(out_list, rank, group=my_pg) | 
 |         self.assertEqual(ranks, out_list) | 
 |  | 
 |     @with_comms() | 
 |     def test_subpg_gather_object(self): | 
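        # Gather onto the first rank of each pair; only that rank supplies and checks the output list.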
 |         rank, ranks, my_pg = self.setup_sub_pg() | 
 |         out_list = [None] * len(ranks) if rank == ranks[0] else None | 
 |         dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg) | 
 |         if rank == ranks[0]: | 
 |             self.assertEqual(ranks, out_list) | 
 |  | 
 |     @with_comms() | 
 |     def test_subpg_broadcast_object(self): | 
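        # The first rank of each pair broadcasts its rank; both members should observe it.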
 |         rank, ranks, my_pg = self.setup_sub_pg() | 
 |         out_list = [None] | 
 |         if rank == ranks[0]: | 
 |             out_list[0] = rank | 
 |         dist.broadcast_object_list(out_list, src=ranks[0], group=my_pg) | 
 |         self.assertEqual(ranks[0], out_list[0]) | 
 |  | 
 | if __name__ == "__main__": | 
 |     run_tests() |