# Owner(s): ["oncall: distributed"]

import os
import sys
import tempfile
import time
from datetime import timedelta
from sys import platform

import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc

if not dist.is_available():
    print("torch.distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

import torch.testing._internal.common_utils as common
from torch.testing._internal.common_distributed import (
    skip_if_win32,
    create_tcp_store,
)
from torch.testing._internal.common_utils import (
    TestCase,
    load_tests,
    run_tests,
    retry_on_connect_failures,
    ADDRESS_IN_USE,
    CONNECT_TIMEOUT,
)

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings.
load_tests = load_tests

if platform == "darwin":
    LOOPBACK = "lo0"
else:
    LOOPBACK = "lo"

DEFAULT_HOSTNAME = "localhost"

torch.backends.cuda.matmul.allow_tf32 = False


def gpus_for_rank(world_size):
    """Multi-GPU tests simulate multiple nodes, each with multiple GPUs.

    The NCCL backend requires every process to use the same number of GPUs,
    so on a single node all visible GPUs are divided evenly into subsets and
    each process uses only its own subset.
    """
    visible_devices = list(range(torch.cuda.device_count()))
    gpus_per_process = torch.cuda.device_count() // world_size
    gpus_for_rank = []
    for rank in range(world_size):
        gpus_for_rank.append(
            visible_devices[rank * gpus_per_process: (rank + 1) * gpus_per_process]
        )
    return gpus_for_rank


class StoreTestBase(object):
    def _create_store(self):
        raise RuntimeError("not implemented")

    def _test_set_get(self, fs):
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        fs.add("key3", 2)
        fs.set("key2", "value2")
        fs.add("key3", 3)
        fs.add("key3", 4)
        fs.add("key3", 5)
        fs.add("key3", 6)
        self.assertEqual(fs.num_keys(), self.num_keys_total)
        self.assertEqual(b"6", fs.get("key"))
        self.assertEqual(b"value0", fs.get("key0"))
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key2"))
        self.assertEqual(b"21", fs.get("key3"))

        fs.set("-key3", "7")
        self.assertEqual(b"7", fs.get("-key3"))
        fs.delete_key("-key3")
        self.assertEqual(fs.num_keys(), self.num_keys_total)

    def test_set_get(self):
        self._test_set_get(self._create_store())

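    # compare_set(key, expected, desired) returns the value stored after the
    # call for an existing key: the current value on a mismatch, the desired
    # value on a match. For a missing key, an empty expected value creates the
    # key; otherwise the expected value is echoed back and nothing is stored.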
    def _test_compare_set(self, store):
        missing_key_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
        self.assertEqual(b"wrong_old_value", missing_key_result)

        store.set("cs_key0", "value0")
        self.assertEqual(b"value0", store.get("cs_key0"))
        old_value_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
        self.assertEqual(b"value0", old_value_result)
        self.assertEqual(b"value0", store.get("cs_key0"))
        new_value_result = store.compare_set("cs_key0", "value0", "new_value0")
        self.assertEqual(b"new_value0", new_value_result)
        self.assertEqual(b"new_value0", store.get("cs_key0"))
        empty_old_value_result = store.compare_set("cs_key1", "", "new_value1")
        self.assertEqual(b"new_value1", empty_old_value_result)
        self.assertEqual(b"new_value1", store.get("cs_key1"))

    def test_compare_set(self):
        self._test_compare_set(self._create_store())

    # This is the number of keys used in test_set_get. It is a class property
    # instead of a hardcoded value in the test because some Store
    # implementations have a different number of keys. In the base case there
    # are 5 keys: key, key0, key1, key2, key3.
    @property
    def num_keys_total(self):
        return 5


class FileStoreTest(TestCase, StoreTestBase):
    def setUp(self):
        super(FileStoreTest, self).setUp()
        self.file = tempfile.NamedTemporaryFile(delete=False)

    def _create_store(self):
        store = dist.FileStore(self.file.name, 1)
        store.set_timeout(timedelta(seconds=300))
        return store

    def test_init_pg_and_rpc_with_same_file(self):
        file = tempfile.NamedTemporaryFile(delete=False)
        # Init RPC using file
        rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
        rpc_backend_options.init_method = f"file://{file.name}"
        rpc.init_rpc("worker", rank=0, world_size=1, rpc_backend_options=rpc_backend_options)

        # Init PG using file
        dist.init_process_group("gloo", rank=0, world_size=1, init_method=f"file://{file.name}")
        dist.destroy_process_group()
        assert os.path.exists(file.name)

        rpc.shutdown()
        os.remove(file.name)

    def test_refcount(self):
        file = tempfile.NamedTemporaryFile(delete=False)
        store = dist.FileStore(file.name, 1)
        store2 = dist.FileStore(file.name, 1)

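        # The backing file is removed only when the last FileStore that
        # references it is destroyed.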
        del store
        assert os.path.exists(file.name)
        del store2
        assert not os.path.exists(file.name)

    @property
    def num_keys_total(self):
        return 6


@skip_if_win32()
class HashStoreTest(TestCase, StoreTestBase):
    def setUp(self):
        super(HashStoreTest, self).setUp()

    def _create_store(self):
        store = dist.HashStore()
        store.set_timeout(timedelta(seconds=300))
        return store


class PrefixStoreTest(TestCase):
    def setUp(self):
        # delete=False because the FileStore cleans up the file itself.
        self.file = tempfile.NamedTemporaryFile(delete=False)

    def test_get_underlying_store(self):
        tcp_store = dist.TCPStore(host_name=DEFAULT_HOSTNAME, port=0, world_size=1, is_master=True)
        hash_store = dist.HashStore()
        file_store = dist.FileStore(self.file.name, world_size=1)
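        # PrefixStore wraps another store and namespaces every key with the
        # given prefix; underlying_store exposes the wrapped store.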
        for store in [tcp_store, hash_store, file_store]:
            with self.subTest(f"Testing getting underlying_store for {type(store)}"):
                prefix_store = dist.PrefixStore("prefix", store)
                self.assertEqual(prefix_store.underlying_store, store)


class PrefixFileStoreTest(TestCase, StoreTestBase):
    def setUp(self):
        super(PrefixFileStoreTest, self).setUp()
        self.file = tempfile.NamedTemporaryFile(delete=False)
        self.filestore = dist.FileStore(self.file.name, 1)
        self.prefix = "test_prefix"
        self.filestore.set_timeout(timedelta(seconds=300))

    def _create_store(self):
        return dist.PrefixStore(self.prefix, self.filestore)

    @property
    def num_keys_total(self):
        return 6


class TCPStoreTest(TestCase, StoreTestBase):
    def _create_store(self):
        store = create_tcp_store()
        store.set_timeout(timedelta(seconds=300))
        return store

    def test_address_already_in_use(self):
        err_msg_reg = "^The server socket has failed to listen on any local "
        with self.assertRaisesRegex(RuntimeError, err_msg_reg):
            addr = DEFAULT_HOSTNAME
            port = common.find_free_port()

            # Use noqa to silence flake8.
            # Need to store in an unused variable here to ensure the first
            # object is not destroyed before the second object is created.
            store1 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
            store2 = dist.TCPStore(addr, port, 1, True)  # noqa: F841

    @retry_on_connect_failures
    def test_multitenancy(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        # Use noqa to silence flake8.
        # Need to store in an unused variable here to ensure the first
        # object is not destroyed before the second object is created.
        store1 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841
        store2 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841

    @skip_if_win32()
    @retry_on_connect_failures
    def test_init_pg_and_rpc_with_same_socket(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        os.environ["MASTER_ADDR"] = addr
        os.environ["MASTER_PORT"] = str(port)

        # We internally use a multi-tenant TCP store. Both PG and RPC should
        # successfully initialize even when using the same socket address.

        dist.init_process_group(
            backend="gloo",
            init_method="env://",
            rank=0,
            world_size=1,
        )

        backend_opts = rpc.TensorPipeRpcBackendOptions(
            init_method=f"tcp://{addr}:{port}"
        )
        rpc.init_rpc(
            name="worker0",
            rank=0,
            world_size=1,
            rpc_backend_options=backend_opts,
        )

        rpc.shutdown()

    # The TCPStore has 6 keys in test_set_get: the 5 keys added by the user
    # and one additional key used to coordinate all the workers.
    @property
    def num_keys_total(self):
        return 6

    def _test_numkeys_delkeys(self, fs):
        # We start off with one init key in the store to coordinate workers
        self.assertEqual(fs.num_keys(), 1)
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        self.assertEqual(fs.num_keys(), 5)
        fs.delete_key("key")
        self.assertEqual(fs.num_keys(), 4)
        fs.set_timeout(timedelta(seconds=2))
        with self.assertRaises(RuntimeError):
            fs.get("key")
        fs.delete_key("key0")
        fs.delete_key("key3")
        self.assertEqual(fs.num_keys(), 2)
        fs.set("key4", "value2")
        self.assertEqual(fs.num_keys(), 3)
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key4"))

    def test_numkeys_delkeys(self):
        self._test_numkeys_delkeys(self._create_store())

    def _create_client(self, index, addr, port, world_size):
        client_store = dist.TCPStore(addr, port, world_size=world_size, timeout=timedelta(seconds=10))
        self.assertEqual("value".encode(), client_store.get("key"))
        client_store.set(f"new_key{index}", f"new_value{index}")
        self.assertEqual(f"next_value{index}".encode(),
                         client_store.compare_set(f"new_key{index}", f"new_value{index}", f"next_value{index}"))

    def _multi_worker_helper(self, world_size):
        addr = DEFAULT_HOSTNAME
        server_store = create_tcp_store(addr, world_size, wait_for_workers=False)
        server_store.set("key", "value")
        port = server_store.port

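        # world_size=None exercises the non-fixed world size mode, in which
        # case a single client is created.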
        num_indices = world_size if world_size else 1
        for i in range(num_indices):
            self._create_client(i, addr, port, world_size)

    def test_multi_worker_with_fixed_world_size(self):
        self._multi_worker_helper(5)

    def test_multi_worker_with_nonfixed_world_size(self):
        self._multi_worker_helper(None)


class PrefixTCPStoreTest(TestCase, StoreTestBase):
    def setUp(self):
        super(PrefixTCPStoreTest, self).setUp()
        self.tcpstore = create_tcp_store()
        self.prefix = "test_prefix"
        self.tcpstore.set_timeout(timedelta(seconds=300))

    def _create_store(self):
        return dist.PrefixStore(self.prefix, self.tcpstore)

    # The PrefixTCPStore has 6 keys in test_set_get: the 5 keys added by the
    # user and one additional key used to coordinate all the workers.
    @property
    def num_keys_total(self):
        return 6


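# A minimal Python-side Store implementation; it is used below to exercise
# the C++ Store trampoline rather than the Python API directly.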
class MyPythonStore(dist.Store):
    def __init__(self):
        super(MyPythonStore, self).__init__()
        self.store = {}

    def set(self, key, value):
        if not isinstance(key, str):
            raise AssertionError("Expected set to be called with string key")
        if type(value) is not bytes:
            raise AssertionError("Expected set to be called with bytes value")
        self.store[key] = value

    def get(self, key):
        value = self.store.get(key, b"")
        if type(value) is not bytes:
            raise AssertionError("Expected get to return bytes value")
        return value

    def add(self, key, value):
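        # add() treats a missing key as 0 and stores the running total as a
        # decimal byte string, matching what _test_set_get asserts above.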
        new = int(self.store.get(key, 0)) + value
        self.set(key, str(new).encode("utf-8"))
        return new


class PythonStoreTest(TestCase):
    def setUp(self):
        super(PythonStoreTest, self).setUp()

    def test_set_get(self):
        # If we were to inherit from StoreTestBase and try to use
        # its test_set_get function, we would exercise the Python
        # API directly, instead of going through the C++ trampoline.
        # We care about testing the C++ trampoline, so run the
        # equivalent of StoreTestBase.test_set_get from C++.
        # See `torch/csrc/distributed/c10d/init.cpp` for the definition
        # of this test function.
        dist._test_python_store(MyPythonStore())


class RendezvousTest(TestCase):
    def test_unknown_handler(self):
        with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"):
            dist.rendezvous("invalid://")

    def test_url_with_node_params(self):
        with self.assertRaisesRegex(AssertionError, "has node-specific arguments"):
            dist.rendezvous("file://foo?rank=12&world_size=16", 12, 16)


class RendezvousEnvTest(TestCase):
    @retry_on_connect_failures
    def test_nominal(self):
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(common.find_free_port())

        # Single rank
        os.environ["RANK"] = "0"
|  | gen0 = dist.rendezvous("env://") | 
|  | store0, rank0, size0 = next(gen0) | 
|  | self.assertEqual(0, rank0) | 
|  | self.assertEqual(1, size0) | 
|  |  | 
|  | store0.set("key0", "value0") | 
|  |  | 
|  | # check with get | 
|  | self.assertEqual(b"value0", store0.get("key0")) | 
|  |  | 
|  |  | 
|  | class RendezvousFileTest(TestCase): | 
|  | def test_common_errors(self): | 
|  | with self.assertRaisesRegex(ValueError, "path missing"): | 
|  | gen = dist.rendezvous("file://?rank=0&world_size=1") | 
|  | next(gen) | 
|  | with self.assertRaisesRegex(ValueError, "rank parameter missing"): | 
|  | gen = dist.rendezvous("file:///tmp/foo?world_size=1") | 
|  | next(gen) | 
|  | with self.assertRaisesRegex(ValueError, "size parameter missing"): | 
|  | gen = dist.rendezvous("file:///tmp/foo?rank=0") | 
|  | next(gen) | 
|  |  | 
|  | def test_nominal(self): | 
|  | with tempfile.NamedTemporaryFile(delete=False) as file: | 
|  | url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2' | 
|  | gen0 = dist.rendezvous(url + "&rank=0") | 
|  | store0, rank0, size0 = next(gen0) | 
|  | self.assertEqual(0, rank0) | 
|  | self.assertEqual(2, size0) | 
|  | gen1 = dist.rendezvous(url + "&rank=1") | 
|  | store1, rank1, size1 = next(gen1) | 
|  | self.assertEqual(1, rank1) | 
|  | self.assertEqual(2, size1) | 
|  |  | 
|  | # Set value on both stores | 
|  | store0.set("key0", "value0") | 
|  | store1.set("key1", "value1") | 
|  |  | 
|  | # Cross check with get | 
|  | self.assertEqual(b"value0", store1.get("key0")) | 
|  | self.assertEqual(b"value1", store0.get("key1")) | 
|  |  | 
|  |  | 
|  | @skip_if_win32() | 
|  | class RendezvousTCPTest(TestCase): | 
|  | def create_tcp_url(self): | 
|  | addr = DEFAULT_HOSTNAME | 
|  | port = common.find_free_port() | 
|  | url = "tcp://%s:%d?world_size=%d" % (addr, port, 1) | 
|  | return url | 
|  |  | 
|  | def test_common_errors(self): | 
|  | with self.assertRaisesRegex(ValueError, "port number missing"): | 
|  | gen = dist.rendezvous("tcp://127.0.0.1?rank=0&world_size=1") | 
|  | next(gen) | 
|  | with self.assertRaisesRegex(ValueError, "rank parameter missing"): | 
|  | gen = dist.rendezvous("tcp://127.0.0.1:23456?world_size=1") | 
|  | next(gen) | 
|  | with self.assertRaisesRegex(ValueError, "size parameter missing"): | 
|  | gen = dist.rendezvous("tcp://127.0.0.1:23456?rank=0") | 
|  | next(gen) | 
|  |  | 
|  | def test_dns_timeout(self): | 
|  | with self.assertRaisesRegex(TimeoutError, "client socket has timed out after.*dnsnotexist"): | 
|  | gen = dist.rendezvous( | 
|  | "tcp://dnsnotexist:23456?world_size=2&rank=0", | 
|  | timeout=timedelta(seconds=1), | 
|  | ) | 
|  | next(gen) | 
|  |  | 
|  | @retry_on_connect_failures | 
|  | def test_nominal(self): | 
|  | url = self.create_tcp_url() | 
|  | gen0 = dist.rendezvous(url + "&rank=0") | 
|  | store0, rank0, size0 = next(gen0) | 
|  | self.assertEqual(0, rank0) | 
|  | self.assertEqual(1, size0) | 
|  |  | 
|  | # Set value on the single store | 
|  | store0.set("key0", "value0") | 
|  |  | 
|  | # check with get | 
|  | self.assertEqual(b"value0", store0.get("key0")) | 
|  |  | 
    @retry_on_connect_failures(connect_errors=(CONNECT_TIMEOUT, ADDRESS_IN_USE))
    def test_tcp_store_timeout_set(self):
        url = self.create_tcp_url()
        test_store_timeout = timedelta(seconds=10)
        gen0 = dist.rendezvous(url + "&rank=0", timeout=test_store_timeout)
        store0, rank0, size0 = next(gen0)
        # This should time out in 10s. If the timeout passed into rendezvous
        # was not respected, it would take much longer to time out.
        start = time.time()
        with self.assertRaisesRegex(RuntimeError, "Timeout"):
            store0.get("nonexistent key")

        end = time.time()
        time_diff = end - start
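        # Allow generous (10x) slack for slow hosts; the check still fails if
        # the store fell back to a much larger default timeout.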
        self.assertGreater(test_store_timeout.seconds * 10, time_diff)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()