| from __future__ import absolute_import, division, print_function, unicode_literals |
| |
| import concurrent.futures |
| import sys |
| import unittest |
| from collections import namedtuple |
| from unittest import mock |
| |
| import torch |
| import torch.distributed as dist |
| import torch.distributed.rpc as rpc |
| from common_utils import load_tests |
| from dist_utils import INIT_METHOD_TEMPLATE, TEST_CONFIG, dist_init |
| from torch.distributed.rpc import RpcBackend |
| from torch.distributed.rpc.internal import PythonUDF, _internal_rpc_pickler |
| |
| |
| def requires_process_group_agent(message=""): |
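    """Skip the decorated test unless the PROCESS_GROUP RPC backend is in use.

    Illustrative usage (hypothetical test name):

        @requires_process_group_agent("PROCESS_GROUP backend specific test, skip")
        def test_something(self):
            ...
    """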
| def decorator(old_func): |
| return unittest.skipUnless( |
| TEST_CONFIG.rpc_backend == RpcBackend.PROCESS_GROUP, |
| message, |
| )(old_func) |
| return decorator |
| |
| |
| VALUE_FUTURE = concurrent.futures.Future() |
| |
| |
| def stub_start_rpc_backend_handler(store, self_name, self_rank, worker_name_to_id): |
| return mock.Mock() # RpcAgent. |
| |
| |
| def set_value(value): |
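    # Complete the module-level VALUE_FUTURE so the test process can read the
    # value back; used to verify the source rank of an incoming RPC.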
| VALUE_FUTURE.set_result(value) |
| |
| |
# The classes and functions below are used to test Python user-defined
# functions, classes, and methods over RPC.
| TensorClass = namedtuple("TensorClass", ["tensors"]) |
| |
| |
| class MyPickleClass: |
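    """A class whose pickling round-trips a PythonUDF through the internal RPC pickler."""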
| def __init__(self): |
| self.t = None |
| |
| def __getstate__(self): |
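        # Serialize a PythonUDF with the internal RPC pickler, which returns
        # the pickled payload and a separate tensor table.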
| (pickled_python_udf, tensors) = _internal_rpc_pickler.serialize( |
| PythonUDF(my_tensor_function, (torch.ones(2, 2), torch.ones(2, 2)), None) |
| ) |
| return (pickled_python_udf, tensors) |
| |
| def __setstate__(self, obj): |
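        # Rebuild the PythonUDF from the payload and tensor table, then run it;
        # the result becomes this instance's state.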
| python_udf = _internal_rpc_pickler.deserialize(obj[0], obj[1]) |
| result = python_udf.func(python_udf.args[0], python_udf.args[1]) |
| self.t = result |
| |
| def set(self, val): |
| self.t = val |
| |
| |
| class MyClass: |
| def __init__(self, a): |
| self.a = a |
| |
| def my_instance_method(self, b): |
| return self.a + b |
| |
| @classmethod |
| def my_class_method(cls, d, e): |
| return d + e |
| |
| @staticmethod |
| def my_static_method(f): |
| return f > 10 |
| |
| |
| def run_nested_pickle(pickle_cls_instance, tensor): |
| return pickle_cls_instance.t + tensor |
| |
| |
| def build_complex_tensors(): |
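    # Nested containers (aliased lists, lists of lists, and a tensor-keyed
    # dict) exercise recursive serialization over RPC.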
| a = torch.ones(3, 3) |
| b = [a, a] |
| c = [b, b] |
| d = [a, b] |
| e = {a: d} |
| return [a, b, c, d, e] |
| |
| |
| def my_function(a, b, c): |
| return a + b + c |
| |
| |
| def my_tensor_function(a, b): |
| return a + b |
| |
| |
| def my_complex_tensor_function(list_input, tensor_class_input, dict_input): |
| res = list_input[0] |
| for t in list_input: |
| res += t |
| for k, v in dict_input.items(): |
| res += v |
| complex_tensors = tensor_class_input.tensors |
| return (res, complex_tensors[0], complex_tensors[1], complex_tensors[2]) |
| |
| |
| def my_rref_function(rref_a, rref_b): |
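    # Materialize both RRef values locally with to_here() before adding them.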
| return rref_a.to_here().wait() + rref_b.to_here().wait() |
| |
| |
| def no_result(): |
| print("do nothing") |
| |
| |
| def nested_rpc(dst): |
| return rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1)) |
| |
| |
| def multi_layer_nested_async_rpc(dst, world_size, ttl): |
    # This function returns immediately without blocking the callee, but it
    # keeps generating nested async RPC requests until ttl reaches zero.
| if ttl > 0: |
| current_dst = "worker{}".format(dst) |
| next_dst = (dst + 1) % world_size |
| rpc.rpc_async( |
| current_dst, |
| multi_layer_nested_async_rpc, |
| args=(next_dst, world_size, ttl - 1), |
| ) |
| return 0 |
| |
| |
| def nested_rref(dst): |
| return ( |
| rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1)), |
| rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 2)), |
| ) |
| |
| |
| def nested_remote(dst): |
| rref = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 3)) |
| return rref.to_here().wait() |
| |
| |
| def rref_forward_chain(dst, world_size, rref, ttl): |
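    # Forward the same RRef through a chain of workers; each hop wraps the
    # returned RRef in a list until ttl reaches zero, where the value is fetched.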
| if ttl > 0: |
| current_dst = "worker{}".format(dst) |
| next_dst = (dst + 1) % world_size |
| ret_rref = rpc.remote( |
| current_dst, rref_forward_chain, args=(next_dst, world_size, rref, ttl - 1) |
| ) |
| return [ret_rref] |
| else: |
| return rref.to_here().wait() |
| |
| |
| def rpc_return_rref(dst): |
| return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1)) |
| |
| |
| def light_rpc(): |
| return 0 |
| |
| |
| def heavy_rpc(tensor): |
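    # Simulate a compute-heavy request with repeated in-place tensor ops.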
| for i in range(1, 100): |
| tensor *= i |
| tensor /= i + 1 |
| return 0 |
| |
| |
| def raise_func(): |
| raise ValueError("Expected error") |
| |
| |
| # load_tests from common_utils is used to automatically filter tests for |
| # sharding on sandcastle. This line silences flake warnings |
| load_tests = load_tests |
| |
| |
| @unittest.skipIf( |
| sys.version_info < (3, 0), |
| "Pytorch distributed rpc package " "does not support python2", |
| ) |
| class RpcTest(object): |
| @property |
| def world_size(self): |
| return 4 |
| |
| @property |
| def init_method(self): |
| return INIT_METHOD_TEMPLATE.format(file_name=self.file_name) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_worker_id(self): |
| n = self.rank + 1 |
| peer_rank = n % self.world_size |
| self_worker_info = rpc.get_worker_info() |
| peer_worker_info = rpc.get_worker_info("worker{}".format(peer_rank)) |
| |
| self.assertEqual(self_worker_info.name, "worker{}".format(self.rank)) |
| self.assertEqual(peer_worker_info.name, "worker{}".format(peer_rank)) |
| |
| with self.assertRaisesRegex(RuntimeError, "Unknown destination worker"): |
            rpc.get_worker_info("WorkerUnknown")
| |
| @dist_init(setup_model_parallel=True) |
| def test_self_add(self): |
| self_worker_info = rpc.get_worker_info() |
| self_worker_name = "worker{}".format(self.rank) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, "does not support making RPC calls to self" |
| ): |
| rpc.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1)) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, "does not support making RPC calls to self" |
| ): |
| rpc.rpc_sync(self_worker_name, torch.add, args=(torch.ones(2, 2), 1)) |
| |
| @mock.patch.object(torch.distributed.autograd, "_init") |
| @mock.patch.object(torch.distributed.rpc.api, "_start_rpc_agent") |
| @dist_init(setup_model_parallel=False) |
| def test_register_rpc_backend_and_start_rpc_backend( |
| self, mock_rpc_agent, mock_dist_autograd_init |
| ): |
| backend_name = "stub_backend" |
| rpc.register_backend( |
| backend_name, stub_start_rpc_backend_handler |
| ) |
| rpc.init_model_parallel( |
| self_name="worker1", |
| backend=backend_name, |
| init_method=self.init_method, |
| self_rank=self.rank, |
| worker_name_to_id=self.worker_name_to_id, |
| ) |
| |
| @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip") |
| @dist_init(setup_model_parallel=False) |
| def test_duplicate_name(self): |
| dist.init_process_group( |
| backend=dist.Backend.GLOO, |
| init_method=self.init_method, |
| rank=self.rank, |
| world_size=self.world_size, |
| ) |
| with self.assertRaisesRegex(RuntimeError, "is not unique"): |
| rpc.init_model_parallel( |
| self_name="duplicate_name", |
| backend=TEST_CONFIG.rpc_backend, |
| init_method=self.init_method, |
| self_rank=self.rank, |
| worker_name_to_id=self.worker_name_to_id, |
| ) |
| rpc.join_rpc() |
| |
| @dist_init(setup_model_parallel=False) |
| def test_reinit(self): |
| dist.init_process_group( |
| backend=dist.Backend.GLOO, |
| init_method=self.init_method, |
| rank=self.rank, |
| world_size=self.world_size, |
| ) |
| rpc.init_model_parallel( |
| self_name="worker{}".format(self.rank), |
| backend=TEST_CONFIG.rpc_backend, |
| init_method=self.init_method, |
| self_rank=self.rank, |
| worker_name_to_id=self.worker_name_to_id, |
| ) |
| with self.assertRaisesRegex(RuntimeError, "is already initialized"): |
| rpc.init_model_parallel( |
| self_name="worker{}".format(self.rank), |
| backend=TEST_CONFIG.rpc_backend, |
| init_method=self.init_method, |
| self_rank=self.rank, |
| worker_name_to_id=self.worker_name_to_id, |
| ) |
| rpc.join_rpc() |
| |
| @dist_init(setup_model_parallel=False) |
| def test_init_invalid_backend(self): |
| with self.assertRaisesRegex(RuntimeError, "Unrecognized RPC backend"): |
| rpc.init_model_parallel( |
| self_name="worker{}".format(self.rank), |
| backend="invalid", |
| init_method=self.init_method, |
| self_rank=self.rank, |
| worker_name_to_id=self.worker_name_to_id, |
| ) |
| |
| @dist_init(setup_model_parallel=False) |
| def test_invalid_names(self): |
| dist.init_process_group( |
| backend=dist.Backend.GLOO, |
| init_method=self.init_method, |
| rank=self.rank, |
| world_size=self.world_size, |
| ) |
| |
| with self.assertRaisesRegex(RuntimeError, "Worker name must match"): |
| rpc.init_model_parallel( |
| self_name="abc*", |
| backend=TEST_CONFIG.rpc_backend, |
| init_method=self.init_method, |
| self_rank=self.rank, |
| worker_name_to_id=self.worker_name_to_id, |
| num_send_recv_threads=16, |
| ) |
| |
| with self.assertRaisesRegex(RuntimeError, "Worker name must match"): |
| rpc.init_model_parallel( |
| self_name=" ", |
| backend=TEST_CONFIG.rpc_backend, |
| init_method=self.init_method, |
| self_rank=self.rank, |
| worker_name_to_id=self.worker_name_to_id, |
| num_send_recv_threads=16, |
| ) |
| |
| with self.assertRaisesRegex(RuntimeError, "must be non-empty"): |
| rpc.init_model_parallel( |
| self_name="", |
| backend=TEST_CONFIG.rpc_backend, |
| init_method=self.init_method, |
| self_rank=self.rank, |
| worker_name_to_id=self.worker_name_to_id, |
| num_send_recv_threads=16, |
| ) |
| |
| # If the number in the message does not match, it is likely that the |
| # value of MAX_NAME_LEN in RPC WorkerInfo has changed. |
| with self.assertRaisesRegex(RuntimeError, "shorter than 128"): |
| rpc.init_model_parallel( |
| self_name="".join(["a" for _ in range(500)]), |
| backend=TEST_CONFIG.rpc_backend, |
| init_method=self.init_method, |
| self_rank=self.rank, |
| worker_name_to_id=self.worker_name_to_id, |
| num_send_recv_threads=16, |
| ) |
| |
| from torch.distributed.rpc.api import _agent |
        self.assertIsNone(_agent)
| # join_rpc() should not do anything as _agent is None |
| rpc.join_rpc() |
        # We need this barrier here because although init_process_group is
        # blocking, it does not guarantee that all ranks are done with
        # initialization after the call. We ran into issues where rank 3
        # crashed with a "connection closed by peer" RuntimeError, caused by
        # other ranks exiting before rank 3 was ready. This can be fixed by
        # adding a collective call to sync all processes.
        #
        # We decided not to fix this in init_process_group because it would
        # add extra overhead to the call, and normal use cases won't create a
        # process group and then exit without doing anything. Hence, it is not
        # worth introducing the overhead just for this test case.
| dist.barrier() |
| |
| @dist_init(setup_model_parallel=True) |
| def test_add(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| ret = rpc.rpc_sync( |
| "worker{}".format(dst_rank), |
| torch.add, |
| args=(torch.ones(n, n), torch.ones(n, n)), |
| ) |
| self.assertEqual(ret, torch.ones(n, n) * 2) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_add_with_id(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
        worker_info = rpc.get_worker_info("worker{}".format(dst_rank))

        ret = rpc.rpc_sync(
            worker_info, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
| ) |
| self.assertEqual(ret, torch.ones(n, n) * 2) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_scalar_add(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| ret = rpc.rpc_sync( |
| "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), n) |
| ) |
| self.assertEqual(ret, (torch.ones(n, n) + n)) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_async_add(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| fut = rpc.rpc_async( |
| "worker{}".format(dst_rank), |
| torch.add, |
| args=(torch.ones(n, n), torch.ones(n, n)), |
| ) |
| self.assertEqual(fut.wait(), torch.ones(n, n) * 2) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_nonzero(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| x = torch.ones(self.world_size, self.world_size) |
| x[self.rank][self.rank] = 0 |
| ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.nonzero, args=(x,)) |
| self.assertEqual(ret, x.nonzero()) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_multi_rpc(self): |
| dst_rank = (self.rank + 1) % self.world_size |
| for i in range(20): |
| n = i + self.rank + 1 |
| ret = rpc.rpc_sync( |
| "worker{}".format(dst_rank), |
| torch.add, |
| args=(torch.ones(n, n), torch.ones(n, n)), |
| ) |
| self.assertEqual(ret, torch.ones(n, n) * 2) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_sync_rpc(self): |
| dst_rank = (self.rank + 1) % self.world_size |
| for i in range(20): |
| rpc.sync_rpc() |
| n = i + self.rank + 1 |
| ret1 = rpc.rpc_sync( |
| "worker{}".format(dst_rank), |
| torch.add, |
| args=(torch.ones(n, n), torch.ones(n, n)), |
| ) |
| rpc.sync_rpc() |
| ret2 = rpc.rpc_sync( |
| "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2) |
| ) |
| rpc.sync_rpc() |
| self.assertEqual(ret1, torch.ones(n, n) * 2) |
| self.assertEqual(ret2, torch.ones(n, n) * 3) |
| |
| @dist_init(setup_model_parallel=False) |
| def test_join_rpc(self): |
| # Initialize RPC. |
| dist.init_process_group( |
| backend="gloo", |
| init_method=self.init_method, |
| rank=self.rank, |
| world_size=self.world_size, |
| ) |
| rpc.init_model_parallel( |
| self_name="worker%d" % self.rank, |
| backend=TEST_CONFIG.rpc_backend, |
| init_method=self.init_method, |
| self_rank=self.rank, |
| worker_name_to_id=self.worker_name_to_id, |
| ) |
| |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| ret = rpc.rpc_sync( |
| "worker{}".format(dst_rank), |
| torch.add, |
| args=(torch.ones(n, n), torch.ones(n, n)), |
| ) |
| self.assertEqual(ret, torch.ones(n, n) * 2) |
| rpc.join_rpc() |
| |
| with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"): |
| rpc.rpc_sync( |
| "worker{}".format(dst_rank), |
| torch.add, |
| args=(torch.ones(n, n), torch.ones(n, n)), |
| ) |
| |
        # It is safe to call join_rpc() multiple times.
| rpc.join_rpc() |
| |
| @dist_init(setup_model_parallel=True) |
| def test_expected_src(self): |
| dst_rank = (self.rank + 1) % self.world_size |
| expected_src_rank = (self.rank - 1) % self.world_size |
| ret = rpc.rpc_sync("worker{}".format(dst_rank), set_value, args=(self.rank,)) |
| value = VALUE_FUTURE.result() |
| self.assertEqual(value, expected_src_rank) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_built_in(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| ret = rpc.rpc_sync("worker{}".format(dst_rank), min, args=(n, n + 1, n + 2)) |
| self.assertEqual(ret, min(n, n + 1, n + 2)) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_user_defined(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| ret = rpc.rpc_sync( |
| "worker{}".format(dst_rank), |
| my_function, |
| kwargs={"a": n, "b": n + 1, "c": n + 2}, |
| ) |
| self.assertEqual(ret, my_function(n, n + 1, n + 2)) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_class_constructor(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| ret = rpc.rpc_sync("worker{}".format(dst_rank), MyClass, args=(n,)) |
| self.assertEqual(ret.a, n) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_class_instance_method(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| ret = rpc.rpc_sync( |
| "worker{}".format(dst_rank), MyClass(2).my_instance_method, args=(n,) |
| ) |
| self.assertEqual(ret, MyClass(2).my_instance_method(n)) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_class_method(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| ret = rpc.rpc_sync( |
| "worker{}".format(dst_rank), MyClass.my_class_method, args=(n, n + 1) |
| ) |
| self.assertEqual(ret, MyClass.my_class_method(n, n + 1)) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_class_static_method(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| ret = rpc.rpc_sync( |
| "worker{}".format(dst_rank), MyClass.my_static_method, args=(n + 10,) |
| ) |
| self.assertEqual(ret, MyClass.my_static_method(n + 10)) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_multi_async_call(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| dst_worker_info = rpc.get_worker_info("worker{}".format(dst_rank)) |
| fut1 = rpc.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,)) |
| fut2 = rpc.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2)) |
| self.assertEqual(fut1.wait(), MyClass.my_static_method(n + 10)) |
| self.assertEqual(fut2.wait(), min(n, n + 1, n + 2)) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_no_return_result(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| ret = rpc.rpc_sync("worker{}".format(dst_rank), no_result) |
| self.assertEqual(ret, no_result()) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_tensors(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| ret = rpc.rpc_sync( |
| "worker{}".format(dst_rank), |
| my_tensor_function, |
| args=(torch.ones(n, n), torch.ones(n, n)), |
| ) |
| self.assertEqual(ret, my_tensor_function(torch.ones(n, n), torch.ones(n, n))) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_tensors_multi_async_call(self): |
| futs = [] |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| for i in range(100): |
| fut = rpc.rpc_async( |
| "worker{}".format(dst_rank), |
| my_tensor_function, |
| args=(torch.ones(i, i), torch.ones(i, i)), |
| ) |
| futs.append(fut) |
| |
        for j, fut in enumerate(futs):
            self.assertEqual(
                fut.wait(), my_tensor_function(torch.ones(j, j), torch.ones(j, j))
            )
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_tensors_in_container(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| a = [torch.ones(n, n), torch.ones(n, n)] |
| b = TensorClass(build_complex_tensors()) |
| c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)} |
| ret = rpc.rpc_sync( |
| "worker{}".format(dst_rank), my_complex_tensor_function, args=(a, b, c) |
| ) |
| self.assertEqual(ret, my_complex_tensor_function(a, b, c)) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_nested_pickle(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| |
| ret = rpc.rpc_sync( |
| "worker{}".format(dst_rank), |
| run_nested_pickle, |
| args=(MyPickleClass(), torch.ones(2, 2)), |
| ) |
| |
| m = MyPickleClass() |
| m.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2))) |
| self.assertEqual(ret, run_nested_pickle(m, torch.ones(2, 2))) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_function_exception(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| with self.assertRaisesRegex(Exception, "TypeError"): |
            rpc.rpc_sync("worker{}".format(dst_rank), no_result, args=(10,))
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_raise_in_user_func(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| fut = rpc.rpc_async("worker{}".format(dst_rank), raise_func) |
| with self.assertRaisesRegex(Exception, "ValueError"): |
| fut.wait() |
| |
| @dist_init(setup_model_parallel=True) |
| def test_nested_rpc(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| ret = rpc.rpc_sync( |
| "worker{}".format(dst_rank), |
| nested_rpc, |
| args=("worker{}".format(self.rank),), |
| ) |
| self.assertEqual(ret, torch.ones(2, 2) + 1) |
| |
| def _stress_test_rpc(self, f, repeat=1000, args=()): |
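        """Fire off ``repeat`` async RPCs to the next worker, verify each
        returns 0, and report the elapsed time."""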
| import time |
| |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| futs = [] |
| tik = time.time() |
| for _ in range(repeat): |
| fut = rpc.rpc_async("worker{}".format(dst_rank), f, args=args) |
| futs.append(fut) |
| |
| for fut in futs: |
| self.assertEqual(fut.wait(), 0) |
| tok = time.time() |
| print( |
| "Rank {} finished testing {} {} times in {} seconds.".format( |
| self.rank, f.__name__, repeat, tok - tik |
| ) |
| ) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_stress_light_rpc(self): |
| self._stress_test_rpc(light_rpc) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_stress_heavy_rpc(self): |
| self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),)) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_builtin_remote_ret(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| rref = rpc.remote( |
| "worker{}".format(dst_rank), |
| torch.add, |
| args=(torch.ones(n, n), torch.ones(n, n)), |
| ) |
| self.assertEqual(rref.to_here().wait(), torch.ones(n, n) * 2) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_asymmetric_load_with_join(self): |
| """Test graceful termination.""" |
| # worker0 drives and waits for worker1 and worker2 |
| # throughout the test. |
| if self.rank == 0: |
| assert self.world_size >= 3 |
| |
| num_repeat = 200 |
| futs = [] |
| |
| # Phase 1: Only worker1 has workload. |
| dst = "worker1" |
| for _ in range(num_repeat): |
| fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),)) |
| futs.append(fut) |
| |
            for fut in futs:
                self.assertEqual(fut.wait(), 0)
| |
| # Phase 2: Only worker2 has workload. |
| # If join is not correctly implemented, |
| # worker2 should be closed by now. |
| dst = "worker2" |
| for _ in range(num_repeat): |
| fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),)) |
| futs.append(fut) |
| |
            for fut in futs:
                self.assertEqual(fut.wait(), 0)
| |
| def _test_multi_remote_call(self, fn, args_fn=lambda x: (), kwargs_fn=lambda x: {}): |
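        """Issue several remote() calls and compare each RRef result against a
        local invocation of fn."""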
| m = 10 |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| rrefs = [] |
| expected = [] |
| for i in range(m): |
| n = n + i |
| rrefs.append( |
| rpc.remote( |
| "worker{}".format(dst_rank), |
| fn, |
| args=args_fn(n), |
| kwargs=kwargs_fn(n), |
| ) |
| ) |
| expected.append(fn(*args_fn(n), **kwargs_fn(n))) |
| |
| for i in range(m): |
| self.assertEqual(rrefs[i].to_here().wait(), expected[i]) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_multi_builtin_remote_ret(self): |
| def args_fn(n): |
| return (torch.ones(n, n), torch.ones(n, n)) |
| |
| self._test_multi_remote_call(torch.add, args_fn=args_fn) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_udf_remote(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| rref = rpc.remote( |
| "worker{}".format(dst_rank), |
| my_function, |
| kwargs={"a": n, "b": n + 1, "c": n + 2}, |
| ) |
| self.assertEqual(rref.to_here().wait(), my_function(n, n + 1, n + 2)) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_multi_py_udf_remote(self): |
| def kwargs_fn(n): |
| return {"a": torch.ones(n, n), "b": torch.ones(n, n), "c": torch.ones(n, n)} |
| |
| self._test_multi_remote_call(my_function, kwargs_fn=kwargs_fn) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_rref_args(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| rref_a = rpc.remote( |
| "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2) |
| ) |
| rref_b = rpc.remote( |
| "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1) |
| ) |
| rref_c = rpc.remote( |
| "worker{}".format(dst_rank), my_rref_function, args=(rref_a, rref_b) |
| ) |
| self.assertEqual(rref_c.to_here().wait(), torch.ones(n, n) + 4) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_rref_args_user_share(self): |
| n = self.rank + 1 |
| owner_rank = n % self.world_size |
| user_rank = (n + 1) % self.world_size |
| rref_a = rpc.remote( |
| "worker{}".format(owner_rank), my_function, args=(torch.ones(n, n), 2, 0) |
| ) |
| rref_b = rpc.remote( |
| "worker{}".format(owner_rank), my_function, args=(torch.ones(n, n), 1, 0) |
| ) |
| rref_c = rpc.remote( |
| "worker{}".format(user_rank), my_rref_function, args=(rref_a, rref_b) |
| ) |
| self.assertEqual(rref_c.to_here().wait(), torch.ones(n, n) + 4) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_py_rpc_rref_args(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| rref_a = rpc.remote( |
| "worker{}".format(dst_rank), my_function, args=(torch.ones(n, n), 2, 0) |
| ) |
| rref_b = rpc.remote( |
| "worker{}".format(dst_rank), my_function, args=(torch.ones(n, n), 1, 0) |
| ) |
| |
| c = rpc.rpc_sync( |
| "worker{}".format(dst_rank), my_rref_function, args=(rref_a, rref_b) |
| ) |
| |
| self.assertEqual(c, torch.ones(n, n) + 4) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_nested_remote(self): |
| n = self.rank + 1 |
| dst_rank1 = n % self.world_size |
| dst_rank2 = (n + 1) % self.world_size |
| |
| rref = rpc.remote( |
| "worker{}".format(dst_rank1), |
| nested_remote, |
| args=("worker{}".format(dst_rank2),), |
| ) |
| self.assertEqual(rref.to_here().wait(), torch.ones(2, 2) + 3) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_nested_rref(self): |
| n = self.rank + 1 |
| dst_rank1 = n % self.world_size |
| dst_rank2 = (n + 1) % self.world_size |
| rref_of_rrefs = rpc.remote( |
| "worker{}".format(dst_rank1), |
| nested_rref, |
| args=("worker{}".format(dst_rank2),), |
| ) |
| rrefs = rref_of_rrefs.to_here().wait() |
| self.assertEqual(len(rrefs), 2) |
| self.assertEqual(rrefs[0].to_here().wait(), torch.ones(2, 2) + 1) |
| self.assertEqual(rrefs[1].to_here().wait(), torch.ones(2, 2) + 2) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_nested_rref_stress(self): |
| n = self.rank + 1 |
| dst_rank1 = n % self.world_size |
| dst_rank2 = (n + 1) % self.world_size |
| all_rrefs = [] |
| for _ in range(20): |
| all_rrefs.append( |
| rpc.remote( |
| "worker{}".format(dst_rank1), |
| nested_rref, |
| args=("worker{}".format(dst_rank2),), |
| ) |
| ) |
| |
| for i in range(20): |
| rref_of_rrefs = all_rrefs[i] |
| rrefs = rref_of_rrefs.to_here().wait() |
| self.assertEqual(len(rrefs), 2) |
| self.assertEqual(rrefs[0].to_here().wait(), torch.ones(2, 2) + 1) |
| self.assertEqual(rrefs[1].to_here().wait(), torch.ones(2, 2) + 2) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_multi_layer_nested_async_rpc(self): |
        # This test will exit right away, but there will be a chain of async
        # RPCs. The termination algorithm should detect those messages
        # properly. Otherwise, some peer could exit early, leaving others to
        # run into timeout or connection-closed errors.
| ttl = 20 |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| |
| multi_layer_nested_async_rpc(dst_rank, self.world_size, ttl) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_remote_with_exception(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| rref = rpc.remote("worker{}".format(dst_rank), raise_func) |
| with self.assertRaisesRegex(Exception, "ValueError"): |
| rref.to_here().wait() |
| |
| @dist_init(setup_model_parallel=True) |
| def test_rpc_return_rref(self): |
| n = self.rank + 1 |
| dst_rank1 = n % self.world_size |
| dst_rank2 = (n + 1) % self.world_size |
| rref = rpc.rpc_sync( |
| "worker{}".format(dst_rank1), |
| rpc_return_rref, |
| args=("worker{}".format(dst_rank2),), |
| ) |
| self.assertEqual(rref.to_here().wait(), torch.ones(2, 2) + 1) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_rref_forward_chain(self): |
| ttl = 8 |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| |
| rref = rpc.remote( |
| "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1) |
| ) |
| |
| ret_rref = rref_forward_chain(dst_rank, self.world_size, rref, ttl) |
| |
        for _ in range(ttl):
| self.assertEqual(len(ret_rref), 1) |
| ret_rref = ret_rref[0].to_here().wait() |
| |
| ret = ret_rref |
| self.assertEqual(ret, torch.add(torch.ones(n, n), 1)) |
| |
| @dist_init(setup_model_parallel=True) |
| def test_remote_same_worker(self): |
| n = self.rank + 1 |
| dst_rank = n % self.world_size |
| rref_a = rpc.remote( |
| "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2) |
| ) |
| rref_b = rpc.remote( |
| "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1) |
| ) |
| rref_c = rpc.remote( |
| "worker{}".format(dst_rank), my_rref_function, args=(rref_a, rref_b) |
| ) |
| self.assertEqual(rref_c.to_here().wait(), torch.ones(n, n) + 4) |
| |
| def test_requires_process_group_agent_decorator(self): |
| @requires_process_group_agent("test_func did not run") |
| def test_func(): |
| return "expected result" |
| |
| if TEST_CONFIG.rpc_backend == RpcBackend.PROCESS_GROUP: |
| self.assertEqual(test_func(), "expected result") |