| # Owner(s): ["oncall: distributed"] |
| |
| import datetime |
| import os |
| import socket |
| import sys |
| import tempfile |
| import threading |
| import time |
| from datetime import timedelta |
| from sys import platform |
| |
| import torch |
| import torch.distributed as dist |
| import torch.distributed.distributed_c10d as c10d |
| import torch.distributed.rpc as rpc |
| from torch.distributed import DistError, DistNetworkError, DistStoreError |
| from torch.testing._internal.common_distributed import MultiThreadedTestCase |
| from torch.testing._internal.common_utils import ( |
| instantiate_parametrized_tests, |
| parametrize, |
| ) |
| |
| if not dist.is_available(): |
| print("torch.distributed not available, skipping tests", file=sys.stderr) |
| sys.exit(0) |
| |
| import torch.testing._internal.common_utils as common |
| from torch.testing._internal.common_distributed import ( |
| create_tcp_store, |
| skip_if_win32, |
| tp_transports, |
| ) |
| from torch.testing._internal.common_utils import ( |
| ADDRESS_IN_USE, |
| CONNECT_TIMEOUT, |
| load_tests, |
| retry_on_connect_failures, |
| run_tests, |
| TestCase, |
| ) |
| |
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

# Name of the loopback network interface differs between macOS and Linux.
if platform == "darwin":
    LOOPBACK = "lo0"
else:
    LOOPBACK = "lo"

DEFAULT_HOSTNAME = "localhost"

# Disable TF32 matmuls so any tensor comparisons in this suite run at full
# float32 precision (NOTE: standard convention across distributed tests).
torch.backends.cuda.matmul.allow_tf32 = False
| |
| |
def gpus_for_rank(world_size):
    """Partition the visible CUDA devices evenly across ``world_size`` ranks.

    Multi-GPU tests simulate multiple nodes on a single machine; the NCCL
    backend requires an equal number of GPUs per process, so each rank gets a
    contiguous, equally sized slice of the visible devices (any remainder
    GPUs are left unassigned).
    """
    device_count = torch.cuda.device_count()
    per_rank = device_count // world_size
    all_devices = list(range(device_count))
    return [
        all_devices[rank * per_rank : (rank + 1) * per_rank]
        for rank in range(world_size)
    ]
| |
| |
class StoreTestBase:
    """Mixin of tests exercising the common Store API (set/get/add/wait/...).

    Subclasses mix this into a TestCase and override ``_create_store`` to
    return a concrete store implementation.
    """

    def _create_store(self):
        # Subclasses must override. Every test invokes this with no
        # arguments, so the stub takes none (it previously declared an unused
        # required parameter ``i``, which would have raised TypeError instead
        # of the intended error). NotImplementedError is a subclass of
        # RuntimeError, so existing RuntimeError handling keeps working.
        raise NotImplementedError("not implemented")

    def _test_set_get_check(self, fs):
        # "key" is built up via add(); its final value is 1+2+3 == 6.
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        fs.add("key3", 2)
        fs.set("key2", "value2")
        fs.add("key3", 3)
        fs.add("key3", 4)
        fs.add("key3", 5)
        fs.add("key3", 6)
        self.assertEqual(fs.num_keys(), self.num_keys_total)
        self.assertEqual(b"6", fs.get("key"))
        self.assertEqual(b"value0", fs.get("key0"))
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key2"))
        # "key3" accumulated 1+2+3+4+5+6 == 21.
        self.assertEqual(b"21", fs.get("key3"))
        self.assertTrue(fs.check(["key3"]))
        self.assertFalse(fs.check(["Randomkey3"]))

        # delete_key must bring the key count back to the baseline.
        fs.set("-key3", "7")
        self.assertEqual(b"7", fs.get("-key3"))
        fs.delete_key("-key3")
        self.assertEqual(fs.num_keys(), self.num_keys_total)

    def test_set_get_check(self):
        self._test_set_get_check(self._create_store())

    def _test_compare_set(self, store):
        """compare_set swaps the value only when the expected value matches."""
        # Missing key + non-empty expected value: nothing is written and the
        # expected value is echoed back.
        missing_key_result = store.compare_set(
            "cs_key0", "wrong_old_value", "new_value0"
        )
        self.assertEqual(b"wrong_old_value", missing_key_result)

        store.set("cs_key0", "value0")
        self.assertEqual(b"value0", store.get("cs_key0"))
        # Mismatched expected value: store unchanged, current value returned.
        old_value_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
        self.assertEqual(b"value0", old_value_result)
        self.assertEqual(b"value0", store.get("cs_key0"))
        # Matching expected value: value is swapped, new value returned.
        new_value_result = store.compare_set("cs_key0", "value0", "new_value0")
        self.assertEqual(b"new_value0", new_value_result)
        self.assertEqual(b"new_value0", store.get("cs_key0"))
        # Empty expected value on a missing key acts as an unconditional set.
        empty_old_value_result = store.compare_set("cs_key1", "", "new_value1")
        self.assertEqual(b"new_value1", empty_old_value_result)
        self.assertEqual(b"new_value1", store.get("cs_key1"))

    def test_compare_set(self):
        self._test_compare_set(self._create_store())

    def _test_simple_wait(self, fs):
        # Waiting on an absent key must raise a timeout error. The previous
        # pattern "[t -i]imeout" relied on an obscure character class (the
        # range from space to 'i' happens to include 'T'); this matches the
        # same "timeout"/"Timeout" messages explicitly.
        with self.assertRaisesRegex(RuntimeError, "[Tt]imeout"):
            fs.wait(["bad_key"], timedelta(seconds=0.25))
        fs.add("good_key", 1)
        fs.wait(["good_key"])

    def test_simple_wait(self):
        self._test_simple_wait(self._create_store())

    def _test_append(self, store):
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        # append() both extends existing keys and creates missing ones.
        store.set("foo", "po")
        store.append("foo", "tato")
        store.append("bar", "po")
        store.append("bar", "tato")
        self.assertEqual(b"potato", store.get("foo"))
        self.assertEqual(b"potato", store.get("bar"))

    def test_append(self):
        self._test_append(self._create_store())

    def _test_multi_set(self, store):
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.multi_set(["foo", "bar"], ["po", "tato"])
        self.assertEqual(b"po", store.get("foo"))
        self.assertEqual(b"tato", store.get("bar"))

    def test_multi_set(self):
        self._test_multi_set(self._create_store())

    def _test_multi_get(self, store):
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.set("foo", "po")
        store.set("bar", "tato")
        v0, v1 = store.multi_get(["foo", "bar"])
        self.assertEqual(b"po", v0)
        self.assertEqual(b"tato", v1)

    def test_multi_get(self):
        self._test_multi_get(self._create_store())

    # This is the number of keys used in test_set_get. Adding this as a class
    # property instead of hardcoding in the test since some Store
    # implementations will have differing number of keys. In the base case,
    # there will be 5 keys: key, key0, key1, key2, key3.
    @property
    def num_keys_total(self):
        return 5
| |
| |
class FileStoreTest(TestCase, StoreTestBase):
    """Runs the common store tests against a file-backed FileStore."""

    def setUp(self):
        super().setUp()
        # delete=False: the FileStore removes the file itself once the last
        # reference to it is dropped (demonstrated by test_refcount below).
        self.file = tempfile.NamedTemporaryFile(delete=False)

    def _create_store(self):
        store = dist.FileStore(self.file.name, 1)
        # Generous timeout so slow CI machines don't flake.
        store.set_timeout(timedelta(seconds=300))
        return store

    def test_init_pg_and_rpc_with_same_file(self):
        """RPC and a gloo process group can share one file:// init file."""
        file = tempfile.NamedTemporaryFile(delete=False)
        # Init RPC using file
        rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
        rpc_backend_options.init_method = f"file://{file.name}"
        rpc_backend_options._transports = tp_transports()
        rpc.init_rpc(
            "worker", rank=0, world_size=1, rpc_backend_options=rpc_backend_options
        )

        # Init PG using file
        dist.init_process_group(
            "gloo", rank=0, world_size=1, init_method=f"file://{file.name}"
        )
        dist.destroy_process_group()
        # The shared init file must survive PG teardown while RPC is still up.
        assert os.path.exists(file.name)

        rpc.shutdown()
        os.remove(file.name)

    def test_refcount(self):
        """The backing file is deleted only when the last FileStore ref dies."""
        file = tempfile.NamedTemporaryFile(delete=False)
        store = dist.FileStore(file.name, 1)
        store2 = dist.FileStore(file.name, 1)

        del store
        assert os.path.exists(file.name)
        del store2
        assert not os.path.exists(file.name)

    # test_set_get observes 6 keys for FileStore — one more than the 5 keys
    # added by the user (presumably an internal bookkeeping key; see the
    # analogous comment on TCPStoreTest).
    @property
    def num_keys_total(self):
        return 6
| |
| |
@skip_if_win32()
class HashStoreTest(TestCase, StoreTestBase):
    """Runs the common store tests against the in-memory HashStore."""

    def _create_store(self):
        # Generous timeout so slow CI machines don't flake.
        hash_store = dist.HashStore()
        hash_store.set_timeout(timedelta(seconds=300))
        return hash_store
| |
| |
class PrefixStoreTest(TestCase):
    """Checks that PrefixStore exposes the store it wraps."""

    def setUp(self):
        # Call through to TestCase.setUp for consistency with the sibling
        # suites in this file (it was previously skipped here).
        super().setUp()
        # delete is false as FileStore will automatically clean up the file
        self.file = tempfile.NamedTemporaryFile(delete=False)

    def test_get_underlying_store(self):
        """PrefixStore.underlying_store must return the wrapped store."""
        tcp_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME, port=0, world_size=1, is_master=True
        )
        hash_store = dist.HashStore()
        file_store = dist.FileStore(self.file.name, world_size=1)
        for store in [tcp_store, hash_store, file_store]:
            with self.subTest(f"Testing getting underlying_store for {type(store)}"):
                prefix_store = dist.PrefixStore("prefix", store)
                self.assertEqual(prefix_store.underlying_store, store)
| |
| |
class PrefixFileStoreTest(TestCase, StoreTestBase):
    """Common store tests against a PrefixStore wrapping a FileStore."""

    def setUp(self):
        super().setUp()
        # delete=False: the FileStore removes the file once all refs are gone.
        self.file = tempfile.NamedTemporaryFile(delete=False)
        self.filestore = dist.FileStore(self.file.name, 1)
        self.prefix = "test_prefix"
        # Generous timeout so slow CI machines don't flake.
        self.filestore.set_timeout(timedelta(seconds=300))

    def _create_store(self):
        return dist.PrefixStore(self.prefix, self.filestore)

    # 5 user keys plus one extra — mirrors FileStoreTest.num_keys_total.
    @property
    def num_keys_total(self):
        return 6
| |
| |
class TCPStoreTest(TestCase, StoreTestBase):
    """Common store tests against a TCPStore, plus TCP-specific behaviors."""

    def _create_store(self):
        store = create_tcp_store()
        # Generous timeout so slow CI machines don't flake.
        store.set_timeout(timedelta(seconds=300))
        return store

    def _create_store_with_ws(self, addr, world_size):
        # wait_for_workers=False: the server must not block waiting for
        # world_size clients to connect.
        return create_tcp_store(addr, world_size, wait_for_workers=False)

    def test_address_already_in_use(self):
        """Binding two master stores to the same port must fail clearly."""
        err_msg_reg = "^The server socket has failed to listen on any local "
        with self.assertRaisesRegex(RuntimeError, err_msg_reg):
            addr = DEFAULT_HOSTNAME
            port = common.find_free_port()

            # Use noqa to silence flake8.
            # Need to store in an unused variable here to ensure the first
            # object is not destroyed before the second object is created.
            store1 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
            store2 = dist.TCPStore(addr, port, 1, True)  # noqa: F841

    @retry_on_connect_failures
    def test_multitenancy(self):
        """With multi_tenant=True, two masters may share one port."""
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        # Use noqa to silence flake8.
        # Need to store in an unused variable here to ensure the first
        # object is not destroyed before the second object is created.
        store1 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841
        store2 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841

    @skip_if_win32()
    @retry_on_connect_failures
    def test_init_pg_and_rpc_with_same_socket(self):
        """PG and RPC can both initialize on one host:port."""
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        os.environ["MASTER_ADDR"] = addr
        os.environ["MASTER_PORT"] = str(port)

        # We internally use a multi-tenant TCP store. Both PG and RPC should successfully
        # initialize even when using the same socket address.

        dist.init_process_group(
            backend="gloo",
            init_method="env://",
            rank=0,
            world_size=1,
        )

        backend_opts = rpc.TensorPipeRpcBackendOptions(
            init_method=f"tcp://{addr}:{port}", _transports=tp_transports()
        )
        rpc.init_rpc(
            name="worker0",
            rank=0,
            world_size=1,
            rpc_backend_options=backend_opts,
        )

        rpc.shutdown()

    @skip_if_win32()
    def test_take_over_listen_socket(self):
        """A pre-bound listen socket can be handed to TCPStore via its fd."""
        listen_sock: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        listen_sock.bind(("localhost", 0))
        addr, port, *_ = listen_sock.getsockname()
        # detach() transfers ownership of the fd to the store.
        listen_fd = listen_sock.detach()

        store = dist.TCPStore(addr, port, 1, is_master=True, master_listen_fd=listen_fd)

        store.set("key", "value")
        self.assertEqual(b"value", store.get("key"))

    # The TCPStore has 6 keys in test_set_get. It contains the 5 keys added by
    # the user and one additional key used for coordinate all the workers.
    @property
    def num_keys_total(self):
        return 6

    def _test_numkeys_delkeys(self, fs):
        # We start off with one init key in the store to coordinate workers
        self.assertEqual(fs.num_keys(), 1)
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        self.assertEqual(fs.num_keys(), 5)
        fs.delete_key("key")
        self.assertEqual(fs.num_keys(), 4)
        # After deletion, reading the key must time out rather than return.
        fs.set_timeout(timedelta(seconds=2))
        with self.assertRaises(RuntimeError):
            fs.get("key")
        fs.delete_key("key0")
        fs.delete_key("key3")
        self.assertEqual(fs.num_keys(), 2)
        fs.set("key4", "value2")
        self.assertEqual(fs.num_keys(), 3)
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key4"))

    def test_numkeys_delkeys(self):
        self._test_numkeys_delkeys(self._create_store())

    def _create_client(self, index, addr, port, world_size):
        # Connects a non-master client and exercises get/set/compare_set.
        client_store = dist.TCPStore(
            addr, port, world_size=world_size, timeout=timedelta(seconds=10)
        )
        self.assertEqual(b"value", client_store.get("key"))
        client_store.set(f"new_key{index}", f"new_value{index}")
        self.assertEqual(
            f"next_value{index}".encode(),
            client_store.compare_set(
                f"new_key{index}", f"new_value{index}", f"next_value{index}"
            ),
        )

    def _multi_worker_helper(self, world_size):
        # world_size=None exercises the "non-fixed world size" server mode.
        addr = DEFAULT_HOSTNAME
        server_store = self._create_store_with_ws(addr, world_size)
        server_store.set("key", "value")
        port = server_store.port

        num_indices = world_size if world_size else 1
        for i in range(num_indices):
            self._create_client(i, addr, port, world_size)

    def test_multi_worker_with_fixed_world_size(self):
        self._multi_worker_helper(5)

    def test_multi_worker_with_nonfixed_world_size(self):
        self._multi_worker_helper(None)

    # The following three tests override the StoreTestBase versions without
    # the has_extended_api() guard: TCPStore always supports the extended API.
    def test_append(self):
        store = self._create_store()
        store.set("foo", "po")
        store.append("foo", "tato")
        store.append("bar", "po")
        store.append("bar", "tato")
        self.assertEqual(b"potato", store.get("foo"))
        self.assertEqual(b"potato", store.get("bar"))

    def test_multi_set(self):
        store = self._create_store()
        store.multi_set(["foo", "bar"], ["po", "tato"])
        self.assertEqual(b"po", store.get("foo"))
        self.assertEqual(b"tato", store.get("bar"))

    def test_multi_get(self):
        store = self._create_store()
        store.set("foo", "po")
        store.set("bar", "tato")
        v0, v1 = store.multi_get(["foo", "bar"])
        self.assertEqual(b"po", v0)
        self.assertEqual(b"tato", v1)

    def test_store_timeout_on_missing_clients(self):
        with self.assertRaisesRegex(
            DistStoreError,
            r"Timed out after \d+ seconds waiting for clients. \d+/\d+ clients joined.",
        ):
            # world_size is 2 so it should timeout
            dist.TCPStore("localhost", 0, 2, True, timeout=timedelta(seconds=2))

        # when wait_for_workers is not set, then there should be no exception raised
        dist.TCPStore(
            "localhost",
            0,
            2,
            True,
            timeout=timedelta(seconds=2),
            wait_for_workers=False,
        )
| |
| |
class LibUvTCPStoreTest(TCPStoreTest):
    """Re-runs every TCPStoreTest case against the libuv TCPStore backend."""

    def _create_store(self):
        uv_store = create_tcp_store(use_libuv=True)
        # Generous timeout so slow CI machines don't flake.
        uv_store.set_timeout(timedelta(seconds=300))
        return uv_store

    def _create_store_with_ws(self, addr, world_size):
        return create_tcp_store(
            addr, world_size, wait_for_workers=False, use_libuv=True
        )
| |
| |
class PrefixTCPStoreTest(TestCase, StoreTestBase):
    """Common store tests against a PrefixStore wrapping a TCPStore."""

    def setUp(self):
        super().setUp()
        self.tcpstore = create_tcp_store()
        self.prefix = "test_prefix"
        # Generous timeout so slow CI machines don't flake.
        self.tcpstore.set_timeout(timedelta(seconds=300))

    def _create_store(self):
        return dist.PrefixStore(self.prefix, self.tcpstore)

    # 5 user keys plus the TCPStore's internal worker-coordination key
    # (see the analogous comment on TCPStoreTest).
    @property
    def num_keys_total(self):
        return 6

    def test_underlying_non_prefix_store(self):
        """_underlying_non_prefix_store unwraps arbitrarily nested prefixes."""
        store = self._create_store()
        doubly_wrapped = dist.PrefixStore(
            self.prefix, dist.PrefixStore(self.prefix, store)
        )
        self.assertEqual(self.tcpstore, store._underlying_non_prefix_store)
        self.assertEqual(
            self.tcpstore, doubly_wrapped._underlying_non_prefix_store
        )
| |
| |
class MyPythonStore(dist.Store):
    """Minimal dict-backed Store used to exercise the C++ trampoline.

    Values are held strictly as ``bytes``; type mismatches raise
    AssertionError so the trampoline's conversions get validated.
    """

    def __init__(self):
        super().__init__()
        self.store = {}

    def set(self, key, value):
        # The trampoline must hand us a str/bytes key and a bytes value.
        if not isinstance(key, (str, bytes)):
            raise AssertionError("Expected set to be called with string key")
        if type(value) is not bytes:
            raise AssertionError("Expected set to be called with bytes value")
        self.store[key] = value

    def get(self, key):
        # Missing keys read back as empty bytes.
        result = self.store.get(key, b"")
        if type(result) is not bytes:
            raise AssertionError("Expected get to return bytes value")
        return result

    def add(self, key, value):
        # A missing key counts as 0; the updated counter is stored as ASCII.
        updated = value + int(self.store.get(key, 0))
        self.set(key, str(updated).encode("utf-8"))
        return updated

    def compare_set(self, key, expected, newValue):
        if type(expected) is not bytes:
            raise AssertionError("compare_set::expected not bytes")
        if type(newValue) is not bytes:
            raise AssertionError("compare_set::newValue not bytes")

        current = self.store.get(key, None)
        # Swap when the expectation matches, or unconditionally when the key
        # does not exist yet; return whatever ends up stored.
        if current is None or current == expected:
            self.store[key] = newValue
            return newValue
        return current
| |
| |
class PythonStoreTest(TestCase):
    def test_set_get(self):
        """Exercise a pure-Python Store through the C++ trampoline.

        Inheriting StoreTestBase would drive the Python API directly; here we
        care about the C++ trampoline, so run the C++ equivalent of
        StoreTestBase.test_set_get instead. See
        `torch/csrc/distributed/c10d/init.cpp` for the definition of this
        test function.
        """
        dist._test_python_store(MyPythonStore())
| |
| |
class RendezvousTest(TestCase):
    def test_unknown_handler(self):
        # An unregistered URL scheme must be rejected up front.
        with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"):
            dist.rendezvous("invalid://")

    def test_url_with_node_params(self):
        # rank/world_size must come from either the URL or the call
        # arguments, not both.
        with self.assertRaisesRegex(AssertionError, "has node-specific arguments"):
            dist.rendezvous("file://foo?rank=12&world_size=16", 12, 16)
| |
| |
class RendezvousEnvTest(TestCase):
    @retry_on_connect_failures
    def test_nominal(self):
        """env:// rendezvous with a single rank yields a usable store."""
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(common.find_free_port())

        # Single rank
        os.environ["RANK"] = "0"
        gen0 = dist.rendezvous("env://")
        store0, rank0, size0 = next(gen0)
        self.assertEqual(0, rank0)
        self.assertEqual(1, size0)

        # Round-trip a value through the store returned by rendezvous.
        store0.set("key0", "value0")
        self.assertEqual(b"value0", store0.get("key0"))
| |
| |
class RendezvousFileTest(TestCase):
    def test_common_errors(self):
        """Malformed file:// URLs raise descriptive ValueErrors."""
        bad_urls = [
            ("path missing", "file://?rank=0&world_size=1"),
            ("rank parameter missing", "file:///tmp/foo?world_size=1"),
            ("size parameter missing", "file:///tmp/foo?rank=0"),
        ]
        for pattern, url in bad_urls:
            with self.assertRaisesRegex(ValueError, pattern):
                next(dist.rendezvous(url))

    def test_nominal(self):
        """Two ranks rendezvous over one file and share the resulting store."""
        with tempfile.NamedTemporaryFile(delete=False) as file:
            url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2'

            store0, rank0, size0 = next(dist.rendezvous(url + "&rank=0"))
            self.assertEqual(0, rank0)
            self.assertEqual(2, size0)

            store1, rank1, size1 = next(dist.rendezvous(url + "&rank=1"))
            self.assertEqual(1, rank1)
            self.assertEqual(2, size1)

            # Values written through one store are visible through the other.
            store0.set("key0", "value0")
            store1.set("key1", "value1")
            self.assertEqual(b"value0", store1.get("key0"))
            self.assertEqual(b"value1", store0.get("key1"))
| |
| |
@skip_if_win32()
class RendezvousTCPTest(TestCase):
    """tcp:// rendezvous: URL validation, DNS failures, timeouts, libuv."""

    def create_tcp_url(self):
        # world_size is fixed to 1; each test appends its own rank parameter.
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()
        url = "tcp://%s:%d?world_size=%d" % (addr, port, 1)
        return url

    def test_common_errors(self):
        """Malformed tcp:// URLs raise descriptive ValueErrors."""
        with self.assertRaisesRegex(ValueError, "port number missing"):
            gen = dist.rendezvous("tcp://127.0.0.1?rank=0&world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
            gen = dist.rendezvous("tcp://127.0.0.1:23456?world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "size parameter missing"):
            gen = dist.rendezvous("tcp://127.0.0.1:23456?rank=0")
            next(gen)

    def test_dns_timeout(self):
        """An unresolvable host surfaces as DistNetworkError (a DistError)."""
        with self.assertRaisesRegex(
            DistNetworkError, "client socket has timed out after.*dnsnotexist"
        ) as manager:
            gen = dist.rendezvous(
                "tcp://dnsnotexist:23456?world_size=2&rank=0",
                timeout=timedelta(seconds=1),
            )
            next(gen)
        self.assertTrue(isinstance(manager.exception, DistError))

    @retry_on_connect_failures
    def test_nominal(self):
        url = self.create_tcp_url()
        gen0 = dist.rendezvous(url + "&rank=0")
        store0, rank0, size0 = next(gen0)
        self.assertEqual(0, rank0)
        self.assertEqual(1, size0)

        # Set value on the single store
        store0.set("key0", "value0")

        # check with get
        self.assertEqual(b"value0", store0.get("key0"))

    @retry_on_connect_failures(connect_errors=(CONNECT_TIMEOUT, ADDRESS_IN_USE))
    def test_tcp_store_timeout_set(self):
        url = self.create_tcp_url()
        test_store_timeout = timedelta(seconds=10)
        gen0 = dist.rendezvous(url + "&rank=0", timeout=test_store_timeout)
        store0, rank0, size0 = next(gen0)
        # this should time out in 10s. If the timeout passed into rendezvous was
        # not respected, it will take much longer to timeout.
        start = time.time()
        with self.assertRaisesRegex(RuntimeError, "Timeout"):
            store0.get("nonexistant key")  # key intentionally absent ("nonexistant" sic)

        end = time.time()
        time_diff = end - start
        # Generous 10x bound to avoid flakiness on slow CI machines.
        self.assertGreater(test_store_timeout.seconds * 10, time_diff)

    def test_tcp_store_timeout_doest_break_client(self):
        """A get() timeout must leave the client usable for later operations."""
        url = self.create_tcp_url()
        test_store_timeout = timedelta(seconds=10)
        gen0 = dist.rendezvous(url + "&rank=0", timeout=test_store_timeout)
        store0, rank0, size0 = next(gen0)
        # this should time out in 10s. If the timeout passed into rendezvous was
        # not respected, it will take much longer to timeout.
        start = time.time()
        with self.assertRaisesRegex(RuntimeError, "Timeout"):
            store0.get("the_key")

        # The same key can still be set and read after the timeout above.
        store0.set("the_key", "x")

        self.assertEqual(b"x", store0.get("the_key"))

        end = time.time()
        time_diff = end - start
        # Generous 10x bound to avoid flakiness on slow CI machines.
        self.assertGreater(test_store_timeout.seconds * 10, time_diff)

    def test_tcp_store_url_with_libuv(self):
        url = self.create_tcp_url()
        gen0 = dist.rendezvous(url + "&rank=0&use_libuv=1")
        store0, rank0, size0 = next(gen0)
        # The resulting store must report the libuv server backend.
        self.assertTrue(store0.libuvBackend)
| |
| |
class DummyStore(dist.Store):
    """Recording Store stub: logs extended-API calls for later inspection."""

    def __init__(self):
        # Call logs that the tests inspect after exercising the store.
        self.appends = []
        self.multi_sets = []
        self.multi_gets = []
        # FIFO queue of canned responses returned by multi_get.
        self.multi_get_res = []
        super().__init__()

    def append(self, key, value):
        self.appends.append((key, value))

    def multi_get(self, keys):
        self.multi_gets.append(keys)
        return self.multi_get_res.pop(0)

    def multi_set(self, keys, values):
        self.multi_sets.append((keys, values))

    def has_extended_api(self):
        # Advertise the extended API so PrefixStore routes calls through us.
        return True
| |
| |
class TestPythonStore(TestCase):
    """Covers the extended Store API (append/multi_get/multi_set) plumbing
    between Python Store subclasses and PrefixStore."""

    def test_optional_methods_fail(self):
        """A bare Store subclass has no extended API and raises on use."""
        class TestStore(dist.Store):
            pass

        store = TestStore()
        self.assertFalse(store.has_extended_api())
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.append("foo", "bar")
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_get(["foo", "bar"])
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_set(["foo", "bar"], [b"v", b"v"])

    def test_has_extended_api_passthrough(self):
        """PrefixStore reports and enforces the wrapped store's (lack of) API."""
        class TestStore(dist.Store):
            pass

        test_store = TestStore()
        store = dist.PrefixStore("p", test_store)
        self.assertFalse(store.has_extended_api())
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.append("foo", "bar")
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_get(["foo", "bar"])
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_set(["foo", "bar"], [b"v", b"v"])

    def test_has_extended_api_roundtrip(self):
        store = DummyStore()
        prefix = dist.PrefixStore("p", store)
        self.assertTrue(prefix.has_extended_api())

    def test_append_roundtrip(self):
        """PrefixStore prepends "p/" to keys before delegating append."""
        store = DummyStore()
        prefix = dist.PrefixStore("p", store)
        prefix.append("foo", "bar")
        self.assertEqual(1, len(store.appends))
        self.assertEqual(("p/foo", b"bar"), store.appends[0])

    def test_multi_get_roundtrip(self):
        """PrefixStore prefixes keys and forwards the wrapped store's result."""
        store = DummyStore()
        prefix = dist.PrefixStore("p", store)
        store.multi_get_res.append([b"x", b"y"])
        res = prefix.multi_get(["foo", "bar"])
        self.assertEqual(1, len(store.multi_gets))
        self.assertEqual(["p/foo", "p/bar"], store.multi_gets[0])
        self.assertEqual([b"x", b"y"], res)

    def test_multi_set_roundtrip(self):
        """PrefixStore prefixes keys before delegating multi_set."""
        store = DummyStore()
        prefix = dist.PrefixStore("p", store)
        prefix.multi_set(["foo", "bar"], [b"x", b"y"])
        self.assertEqual(1, len(store.multi_sets))
        self.assertEqual(["p/foo", "p/bar"], store.multi_sets[0][0])
        self.assertEqual([b"x", b"y"], store.multi_sets[0][1])

    def test_extended_methods_fallbacks(self):
        """Without the extended API, append/multi_* fall back to set/get."""
        test_store = MyPythonStore()
        store = dist.PrefixStore("p", test_store)
        self.assertFalse(store.has_extended_api())
        store.append("foo", b"po")
        store.append("foo", b"tato")
        self.assertEqual(store.get("foo"), b"potato")

        store.multi_set(["a", "b"], [b"c", b"d"])
        self.assertEqual(store.multi_get(["a", "b", "foo"]), [b"c", b"d", b"potato"])
| |
| |
class TestMultiThreadedWait(MultiThreadedTestCase):
    """Checks Store.wait across two threads for every store implementation."""

    # TODO: Use less hacky means of instantiating stores.
    # Note, stores accumulate values per test.
    stores = [
        dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1),
        dist.HashStore(),
        dist.PrefixStore(
            "pre", dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1)
        ),
        create_tcp_store(),
        create_tcp_store(use_libuv=True),
        dist.PrefixStore("pre", create_tcp_store()),
        dist.PrefixStore("pre", create_tcp_store(use_libuv=True)),
    ]

    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    # Iterates over self.stores, keep 7 in sync with len(self.stores).
    @parametrize("i", range(7))
    def test_wait(self, i):
        # Rank 1 sets the key; rank 0 blocks in wait() until it appears.
        store = self.stores[i]
        store.set_timeout(timedelta(seconds=2))
        if dist.get_rank() == 0:
            store.wait(["key1"])
            self.assertEqual(b"value1", store.get("key1"))
        if dist.get_rank() == 1:
            store.set("key1", "value1")


# Expand the @parametrize'd tests into concrete test methods.
instantiate_parametrized_tests(TestMultiThreadedWait)
| |
| |
@skip_if_win32()
class TimeoutTest(TestCase):
    """A SIGUSR1 delivered mid-wait must not break Store.wait."""

    def tearDown(self):
        import signal

        super().tearDown()
        # Reset the handler installed by the test so later tests see the
        # default (ignored) disposition.
        signal.signal(signal.SIGUSR1, signal.SIG_IGN)

    def test_interrupt_doesnt_break_wait(self):
        import signal

        # Per-rank outcome: True on success, the caught exception on failure.
        rank_res = [None, None]

        def run(rank, my_store):
            nonlocal rank_res
            try:
                if rank == 0:
                    # Delay the set so rank 1 is interrupted while blocked in
                    # wait() (the signal is sent after ~1s, see below).
                    time.sleep(4)
                    my_store.set("foo", "bar")
                else:
                    my_store.wait(["foo"], datetime.timedelta(seconds=10))
                rank_res[rank] = True
            except Exception as e:
                # Bug fix: this previously caught the undefined name ``Error``
                # (flagged by noqa: F821), which would have raised NameError
                # inside the thread instead of recording the failure.
                rank_res[rank] = e
            time.sleep(1)

        rank0_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME,
            port=0,
            world_size=2,
            is_master=True,
            wait_for_workers=False,
        )
        rank1_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME,
            port=rank0_store.port,
            world_size=2,
            is_master=False,
            wait_for_workers=False,
        )

        ths = []
        for i in range(2):
            t = threading.Thread(
                target=run,
                args=(
                    i,
                    [rank0_store, rank1_store][i],
                ),
            )
            t.start()
            ths.append(t)

        def handler(a, b):
            # No-op handler: we only need the blocking call to be interrupted.
            pass

        signal.signal(signal.SIGUSR1, handler)
        time.sleep(1)
        # Interrupt rank 1 while it is (presumably) blocked in wait().
        signal.pthread_kill(ths[1].ident, signal.SIGUSR1)

        for t in ths:
            t.join()
        # Both ranks must have completed successfully despite the signal.
        self.assertTrue(rank_res[0], "rank0")
        self.assertTrue(rank_res[1], "rank1")
| |
| |
class InitPgWithUvStore(TestCase):
    """init_process_group must use the libuv TCPStore backend when requested
    either via URL query parameter or environment variable."""

    def tearDown(self):
        super().tearDown()
        # Remove env vars set by the tests so later suites see a clean env.
        os.environ.pop("USE_LIBUV", None)
        os.environ.pop("MASTER_ADDR", None)
        os.environ.pop("MASTER_PORT", None)

    def test_with_url_param(self):
        # libuv selected via the ?use_libuv=1 URL query parameter.
        port = common.find_free_port()
        dist.init_process_group(
            "gloo",
            rank=0,
            world_size=1,
            init_method=f"tcp://{DEFAULT_HOSTNAME}:{port}?use_libuv=1",
        )
        self._run_test()

    def test_with_env_var(self):
        # libuv selected via the USE_LIBUV environment variable.
        port = common.find_free_port()
        os.environ["USE_LIBUV"] = "1"
        os.environ["MASTER_ADDR"] = DEFAULT_HOSTNAME
        os.environ["MASTER_PORT"] = str(port)
        dist.init_process_group("gloo", rank=0, world_size=1, init_method="env://")
        self._run_test()

    def _run_test(self):
        # Unwrap the (possibly multiply) prefixed store backing the default
        # process group and verify the underlying TCPStore uses libuv.
        pg = dist.group.WORLD
        store = c10d._get_process_group_store(pg)
        self.assertTrue(isinstance(store, dist.PrefixStore))
        # c10d does multiple levels of wrapping
        while isinstance(store, dist.PrefixStore):
            store = store.underlying_store
        self.assertTrue(isinstance(store, dist.TCPStore))
        self.assertTrue(store.libuvBackend)
        dist.destroy_process_group()
| |
| |
if __name__ == "__main__":
    # Guard against the main process having touched CUDA before the tests run
    # (NOTE(review): presumably because distributed tests fork/spawn workers
    # that must initialize CUDA themselves — confirm).
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()