# Source blob: 6a0fe33cd8ad62363f9f8d3a6d6f3f7feec2a195 (appears to be code-viewer header residue, not part of the module)
# Owner(s): ["oncall: distributed"]
import sys
import torch.distributed as dist
# Stop here, with a success exit status, when this torch build reports no
# distributed support. The guard runs ahead of the distributed test-helper
# imports that follow it -- presumably because those imports need the
# distributed package; confirm before reordering.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)
from torch.testing._internal.common_distributed import (
spawn_threads_and_init_comms,
MultiThreadedTestCase
)
from torch.testing._internal.common_utils import TestCase, run_tests
# Default world size (number of ranks) for this module; presumably meant to be
# the single shared value for the test cases defined here -- confirm before
# changing it.
DEFAULT_WORLD_SIZE = 4
class TestObjectCollectivesWithWrapper(TestCase):
    """Object-collective test whose ranks are set up by a decorator.

    The test method is wrapped with ``spawn_threads_and_init_comms`` and told
    which world size to use. NOTE(review): the decorator's name suggests it
    runs the body once per rank in separate threads with communication
    already initialised -- confirm against ``common_distributed``.
    """

    @spawn_threads_and_init_comms(world_size=DEFAULT_WORLD_SIZE)
    def test_broadcast_object_list(self):
        """Broadcast rank 0's object list and verify every slot on this rank."""
        world_size = dist.get_world_size()

        # Only rank 0 starts with the real payload; every other rank starts
        # with ``None`` placeholders that the broadcast has to overwrite.
        payload = 99 if dist.get_rank() == 0 else None
        object_list = [payload] * world_size

        # The source rank is spelled out so it visibly matches the rank that
        # seeded the payload above.
        dist.broadcast_object_list(object_list=object_list, src=0)

        # Check the whole list rather than just the first slot, so a partly
        # delivered broadcast cannot slip through.
        self.assertEqual(world_size, len(object_list))
        for index, received in enumerate(object_list):
            self.assertEqual(99, received, msg=f"slot {index} was not updated")
class TestObjectCollectivesWithBaseClass(MultiThreadedTestCase):
    """Object-collective tests built on the ``MultiThreadedTestCase`` base.

    The class supplies its world size through the ``world_size`` property.
    NOTE(review): the base class presumably reads that property to decide how
    many ranks to run each test with -- confirm against ``common_distributed``.
    """

    @property
    def world_size(self):
        """Return the number of ranks these tests run with."""
        return DEFAULT_WORLD_SIZE

    def test_broadcast_object_list(self):
        """Broadcast rank 0's object list and verify every slot on this rank."""
        world_size = dist.get_world_size()

        # Only rank 0 starts with the real payload; every other rank starts
        # with ``None`` placeholders that the broadcast has to overwrite.
        payload = 99 if dist.get_rank() == 0 else None
        object_list = [payload] * world_size

        # The source rank is spelled out so it visibly matches the rank that
        # seeded the payload above.
        dist.broadcast_object_list(object_list=object_list, src=0)

        # Check the whole list rather than just the first slot, so a partly
        # delivered broadcast cannot slip through.
        self.assertEqual(world_size, len(object_list))
        for index, received in enumerate(object_list):
            self.assertEqual(99, received, msg=f"slot {index} was not updated")

    def test_something_else(self):
        """Do nothing.

        NOTE(review): placeholder with no assertions; presumably kept so the
        class contains more than one test method -- confirm the intent.
        """
# Invoke the shared test runner only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run_tests()