blob: 432bc557cae77e5664abc4b962d2bfe4d54b4bbc [file] [log] [blame]
from __future__ import absolute_import, division, print_function, unicode_literals
import concurrent.futures
from datetime import timedelta
import sys
import time
import unittest
from collections import namedtuple
from unittest import mock
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef
from common_utils import load_tests
import dist_utils
from dist_utils import dist_init
from torch.distributed.rpc.api import _use_rpc_pickler
from torch.distributed.rpc.internal import PythonUDF, _internal_rpc_pickler
from rpc_agent_test_fixture import RpcAgentTestFixture
def requires_process_group_agent(message=""):
    """Build a decorator that skips a test unless the PROCESS_GROUP backend is active.

    Args:
        message: Reason reported by unittest when the test is skipped.

    Returns:
        A decorator wrapping the test with ``unittest.skipUnless``.
    """
    def decorator(old_func):
        # The backend is read when the decorator is applied, not earlier.
        backend_is_pg = dist_utils.TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP"
        skip_unless_pg = unittest.skipUnless(backend_is_pg, message)
        return skip_unless_pg(old_func)

    return decorator
# Module-level futures. set_value / set_and_check_done resolve VALUE_FUTURE,
# and set_and_check_done blocks on DONE_FUTURE until a test resolves it.
VALUE_FUTURE, DONE_FUTURE = concurrent.futures.Future(), concurrent.futures.Future()
def _stub_construct_rpc_backend_options_handler(
**kwargs
):
return mock.Mock() # RpcBackendOptions.
def _stub_start_rpc_backend_handler(
store, name, rank, world_size, rpc_backend_options
):
return mock.Mock() # RpcAgent.
def set_value(value):
    """Resolve the module-level ``VALUE_FUTURE`` with ``value``; returns None."""
    target_future = VALUE_FUTURE
    target_future.set_result(value)
def set_and_check_done(value):
    """Publish ``value`` through ``VALUE_FUTURE``, then block on ``DONE_FUTURE``.

    Returns:
        Whatever ``DONE_FUTURE`` was (or will be) resolved with.
    """
    VALUE_FUTURE.set_result(value)
    done_value = DONE_FUTURE.result()
    return done_value
# The helper functions and classes in this section are invoked over RPC to
# test Python user-defined functions, classes, and methods.
# TensorClass is a namedtuple with a single ``tensors`` field, used to pass
# tensors nested inside a namedtuple as an RPC argument.
TensorClass = namedtuple("TensorClass", ["tensors"])
class MyPickleClass:
    """Object whose pickled state is an RPC-pickler-serialized PythonUDF.

    Pickling ignores the current ``t``. It always serializes a call of
    my_tensor_function on two 2x2 ones tensors. Unpickling runs that call and
    stores the result in ``t``, which exercises nested use of the RPC pickler.
    """

    def __init__(self):
        self.t = None

    def __getstate__(self):
        udf = PythonUDF(my_tensor_function, (torch.ones(2, 2), torch.ones(2, 2)), None)
        serialized = _internal_rpc_pickler.serialize(udf)
        pickled_python_udf, tensors = serialized
        return (pickled_python_udf, tensors)

    def __setstate__(self, obj):
        pickled_python_udf, tensors = obj[0], obj[1]
        python_udf = _internal_rpc_pickler.deserialize(pickled_python_udf, tensors)
        first_arg, second_arg = python_udf.args[0], python_udf.args[1]
        self.t = python_udf.func(first_arg, second_arg)

    def set(self, val):
        """Overwrite the stored value ``t``."""
        self.t = val
class MyClass:
    """Small stateful class whose methods are invoked over RPC by the tests."""

    def __init__(self, a):
        self.a = a

    def my_instance_method(self, b):
        """Return the stored value plus ``b``."""
        return self.a + b

    @classmethod
    def my_class_method(cls, d, e):
        """Return ``d + e`` (the class itself is unused)."""
        return d + e

    @staticmethod
    def my_static_method(f):
        """Return whether ``f`` is greater than 10."""
        return f > 10

    def increment_value(self, increment):
        """Add ``increment`` to the stored value in place."""
        self.a += increment

    def get_value(self):
        """Return the stored value."""
        return self.a
def _call_method_on_rref(method, rref, *args, **kwargs):
return method(rref.local_value(), *args, **kwargs)
def get_rref_list(values):
    """Return a list with one local RRef wrapping ``MyClass(v)`` for each ``v``."""
    rrefs = []
    for value in values:
        rrefs.append(RRef(MyClass(value)))
    return rrefs
def add_rref_to_value(rref, value):
    """Fetch the RRef's value with ``to_here()`` and add ``value`` to it."""
    fetched = rref.to_here()
    return fetched + value
def run_nested_pickle(pickle_cls_instance, tensor):
    """Add ``tensor`` to the ``t`` attribute of a (typically unpickled) instance."""
    held = pickle_cls_instance.t
    return held + tensor
def build_complex_tensors():
    """Build nested lists and a dict that all reference one shared 3x3 tensor.

    Returns:
        ``[base, pair, pair_of_pairs, mixed, keyed]`` where ``pair`` is
        ``[base, base]``, ``pair_of_pairs`` is ``[pair, pair]``, ``mixed`` is
        ``[base, pair]`` and ``keyed`` maps ``base`` to ``mixed``.
    """
    base = torch.ones(3, 3)
    pair = [base, base]
    pair_of_pairs = [pair, pair]
    mixed = [base, pair]
    keyed = {base: mixed}
    return [base, pair, pair_of_pairs, mixed, keyed]
def my_function(a, b, c):
    """Return ``a + b + c``, evaluated left to right."""
    partial = a + b
    return partial + c
def my_tensor_function(a, b):
    """Return the sum ``a + b`` (used with tensors in the tests)."""
    total = a + b
    return total
def my_sleep_func(seconds=1):
    """Sleep for ``seconds`` seconds (default 1) and return None."""
    duration = seconds
    time.sleep(duration)
def my_complex_tensor_function(list_input, tensor_class_input, dict_input):
    """Combine tensors taken from a list, a namedtuple-like wrapper and a dict.

    Args:
        list_input: Sequence of tensors. The running sum starts from a copy of
            the first element and then adds every element, including the first
            again. For distinct tensors the result is
            ``2 * list_input[0] + sum(list_input[1:])``.
        tensor_class_input: Object with a ``tensors`` attribute holding at
            least three items (e.g. ``TensorClass``).
        dict_input: Mapping whose values are added to the running sum.

    Returns:
        ``(res, tensors[0], tensors[1], tensors[2])``, where ``res`` is the
        running sum and ``tensors`` is ``tensor_class_input.tensors``.
    """
    # Clone so the in-place additions below do not mutate the caller's tensor.
    # Previously ``res`` aliased ``list_input[0]``, and ``res += t`` modified
    # the argument.
    res = list_input[0].clone()
    for t in list_input:
        res += t
    for v in dict_input.values():
        res += v
    complex_tensors = tensor_class_input.tensors
    return (res, complex_tensors[0], complex_tensors[1], complex_tensors[2])
def my_rref_function(rref_a, rref_b):
    """Fetch both RRef values (``rref_a`` first) and return their sum."""
    value_a = rref_a.to_here()
    value_b = rref_b.to_here()
    return value_a + value_b
def no_result():
    """Print a fixed message and implicitly return None; takes no arguments."""
    message = "do nothing"
    print(message)
def nested_rpc(dst):
    """Issue a synchronous RPC to ``dst`` computing ``ones(2, 2) + 1`` and return it."""
    operands = (torch.ones(2, 2), 1)
    return rpc.rpc_sync(dst, torch.add, args=operands)
def multi_layer_nested_async_rpc(dst, world_size, ttl):
    """Start a ring of ``ttl`` chained async RPCs, without waiting, and return 0.

    While ``ttl > 0``, this schedules itself on ``"worker{dst}"`` with the next
    rank in the ring and ``ttl - 1``. The returned future is never waited on,
    so the callee returns immediately while further requests are generated.
    """
    if not (ttl > 0):
        return 0
    next_dst = (dst + 1) % world_size
    rpc.rpc_async(
        "worker{}".format(dst),
        multi_layer_nested_async_rpc,
        args=(next_dst, world_size, ttl - 1),
    )
    return 0
def nested_rref(dst):
    """Create ``ones(2, 2) + 1`` and ``ones(2, 2) + 2`` remotely on ``dst``.

    Returns:
        A 2-tuple of the two RRefs, in that order.
    """
    return tuple(
        rpc.remote(dst, torch.add, args=(torch.ones(2, 2), offset))
        for offset in (1, 2)
    )
def nested_remote(dst):
    """Compute ``ones(2, 2) + 3`` via rpc.remote on ``dst`` and return the fetched value."""
    return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 3)).to_here()
def rref_forward_chain(dst, world_size, rref, ttl):
    """Forward ``rref`` through ``ttl`` nested rpc.remote hops around the ring.

    Returns:
        While ``ttl > 0``, a one-element list holding the RRef of the next
        hop's result. At the final hop, the value of ``rref``.
    """
    if not (ttl > 0):
        return rref.to_here()
    next_dst = (dst + 1) % world_size
    hop = rpc.remote(
        "worker{}".format(dst),
        rref_forward_chain,
        args=(next_dst, world_size, rref, ttl - 1),
    )
    return [hop]
def rpc_return_rref(dst):
    """Create ``ones(2, 2) + 1`` remotely on ``dst`` and return its RRef."""
    created = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))
    return created
def light_rpc():
    """Do no work and return 0; the cheapest possible RPC target."""
    result = 0
    return result
def heavy_rpc(tensor):
    """Burn CPU with in-place tensor arithmetic and return 0.

    Mutates ``tensor`` in place. For i in 1..99 it is multiplied by i and then
    divided by i + 1, so mathematically it ends up scaled by 1/100.
    """
    for factor in range(1, 100):
        tensor *= factor
        tensor /= factor + 1
    return 0
def raise_func():
    """Always raise ``ValueError("Expected error")``."""
    error = ValueError("Expected error")
    raise error
# Module-level slot in which set_global_rref keeps an RRef (received as an RPC
# argument by the tests) alive until clear_global_rref drops it.
global_rref = None


def set_global_rref(rref):
    """Store ``rref`` in the module-level ``global_rref``."""
    global global_rref
    global_rref = rref


def clear_global_rref():
    """Reset the module-level ``global_rref`` to None."""
    set_global_rref(None)
# ``load_tests`` from common_utils is used to automatically filter tests for
# sharding on Sandcastle. Re-binding the imported name here silences flake8's
# unused-import warning.
load_tests = load_tests
@unittest.skipIf(
sys.version_info < (3, 0),
"Pytorch distributed rpc package " "does not support python2",
)
class RpcTest(RpcAgentTestFixture):
@dist_init
def test_worker_id(self):
    """get_worker_info() resolves self and a peer; unknown names raise."""
    peer_rank = (self.rank + 1) % self.world_size
    self_worker_info = rpc.get_worker_info()
    peer_worker_info = rpc.get_worker_info("worker{}".format(peer_rank))

    self.assertEqual(self_worker_info.name, "worker{}".format(self.rank))
    self.assertEqual(peer_worker_info.name, "worker{}".format(peer_rank))

    with self.assertRaisesRegex(RuntimeError, "Unknown destination worker"):
        # The lookup itself must raise; its result is never used.
        rpc.get_worker_info("WorkerUnknown")
@dist_init
def test_get_worker_infos(self):
    """The agent reports every worker's name and id exactly once."""
    infos = rpc.api._agent.get_worker_infos()

    names = set(info.name for info in infos)
    self.assertEqual(
        names, set("worker{}".format(r) for r in range(self.world_size))
    )

    ids = set(info.id for info in infos)
    self.assertEqual(ids, set(range(self.world_size)))
@dist_init
def test_self_add(self):
    """Sync and async builtin RPCs addressed to the calling worker itself."""
    self_worker_info = rpc.get_worker_info()
    fut = rpc.rpc_async(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
    ret = rpc.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
    expected = torch.ones(2, 2) + 1
    self.assertEqual(fut.wait(), expected)
    self.assertEqual(ret, expected)
@dist_init
def test_self_py_udf_remote(self):
    """rpc.remote of a Python UDF on the calling worker itself."""
    me = rpc.get_worker_info()
    rref = rpc.remote(me, my_function, args=(torch.ones(2, 2), 1, 3))
    expected = torch.ones(2, 2) + 1 + 3
    self.assertEqual(rref.to_here(), expected)
def _test_self_remote_rref_as_rpc_arg(self, dst):
    """Pass a self-owned RRef as an argument to rpc_async and rpc_sync on ``dst``."""
    owner = rpc.get_worker_info()
    rref = rpc.remote(owner, my_function, args=(torch.ones(2, 2), 1, 3))
    async_result = rpc.rpc_async(dst, add_rref_to_value, args=(rref, torch.ones(2, 2)))
    sync_result = rpc.rpc_sync(
        dst, add_rref_to_value, args=(rref, torch.ones(2, 2) + 1)
    )
    rref_value = torch.ones(2, 2) + 1 + 3
    self.assertEqual(sync_result, rref_value + torch.ones(2, 2) + 1)
    self.assertEqual(async_result.wait(), rref_value + torch.ones(2, 2))
@dist_init
def test_self_remote_rref_as_rpc_arg(self):
    """A self-owned RRef sent as an RPC argument to the next worker in the ring."""
    next_rank = (self.rank + 1) % self.world_size
    self._test_self_remote_rref_as_rpc_arg("worker{}".format(next_rank))
@dist_init
def test_self_remote_rref_as_self_rpc_arg(self):
    """A self-owned RRef sent as an RPC argument back to the calling worker."""
    me = rpc.get_worker_info()
    self._test_self_remote_rref_as_rpc_arg(me)
def _test_self_remote_rref_as_remote_arg(self, dst):
    """Pass a self-owned RRef as an argument to rpc.remote on ``dst``."""
    owner = rpc.get_worker_info()
    rref = rpc.remote(owner, my_function, args=(torch.ones(2, 2), 1, 3))
    result_rref = rpc.remote(dst, add_rref_to_value, args=(rref, torch.ones(2, 2)))
    expected = torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2)
    self.assertEqual(result_rref.to_here(), expected)
@dist_init
def test_self_remote_rref_as_remote_arg(self):
    """A self-owned RRef used as an rpc.remote argument on the next worker."""
    next_rank = (self.rank + 1) % self.world_size
    self._test_self_remote_rref_as_remote_arg("worker{}".format(next_rank))
@dist_init
def test_self_remote_rref_as_self_remote_arg(self):
    """A self-owned RRef used as an rpc.remote argument on the calling worker."""
    me = rpc.get_worker_info()
    self._test_self_remote_rref_as_remote_arg(me)
@mock.patch.object(torch.distributed.autograd, "_init")
@mock.patch.object(torch.distributed.rpc.api, "_start_rpc_agent")
@dist_init(setup_rpc=False)
def test_register_rpc_backend_and_start_rpc_backend(
    self, mock_rpc_agent, mock_dist_autograd_init
):
    """Register a stub backend, reject a duplicate registration, then init with it.

    Stacked mock.patch decorators pass their mocks innermost-first.
    ``mock_rpc_agent`` therefore replaces ``rpc.api._start_rpc_agent`` and
    ``mock_dist_autograd_init`` replaces ``torch.distributed.autograd._init``.
    """
    backend_name = "stub_backend"
    handlers = (
        _stub_construct_rpc_backend_options_handler,
        _stub_start_rpc_backend_handler,
    )

    backend = rpc.backend_registry.register_backend(backend_name, *handlers)

    # Registering the same name a second time must be refused.
    with self.assertRaisesRegex(
        RuntimeError, "^RPC backend .+: already registered$"
    ):
        rpc.backend_registry.register_backend(backend_name, *handlers)

    rpc.init_rpc(
        name="worker1",
        backend=backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
@dist_init(setup_rpc=False)
def test_duplicate_name(self):
    """Every worker using the same name must make backend init fail as "not unique"."""
    # Rendezvous outside the assertRaises block so that only the backend init
    # is expected to raise. A rendezvous failure now surfaces as its own error
    # instead of being checked against the "is not unique" pattern.
    store, _, _ = next(
        torch.distributed.rendezvous(
            self.init_method, rank=self.rank, world_size=self.world_size
        )
    )
    with self.assertRaisesRegex(RuntimeError, "is not unique"):
        rpc._init_rpc_backend(
            backend=self.rpc_backend,
            store=store,
            name="duplicate_name",
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )
    rpc.shutdown()
@dist_init(setup_rpc=False)
def test_reinit(self):
    """Calling init_rpc a second time on an initialized worker must raise."""
    def init_this_worker():
        rpc.init_rpc(
            name="worker{}".format(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

    init_this_worker()

    # dist.barrier below needs a default process group. Agents other than
    # ProcessGroupAgent do not create one, so create a gloo group if needed.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )
    # Ensure every worker has finished its first init.
    dist.barrier()

    with self.assertRaisesRegex(RuntimeError, "is already initialized"):
        init_this_worker()
    rpc.shutdown()
@dist_init(setup_rpc=False)
def test_invalid_names(self):
    """WorkerInfo rejects names with invalid characters, empty names and over-long names."""
    from torch.distributed.rpc import WorkerInfo

    worker_id = 0
    with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
        WorkerInfo("abc*", worker_id)

    with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
        WorkerInfo(" ", worker_id)

    with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
        WorkerInfo("", worker_id)

    # If the number in the message does not match, it is likely that the
    # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
    with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
        WorkerInfo("a" * 500, worker_id)
@dist_init
def test_add(self):
    """Builtin torch.add over rpc_sync with n x n operands, where n = rank + 1."""
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    lhs, rhs = torch.ones(n, n), torch.ones(n, n)
    ret = rpc.rpc_sync(dst, torch.add, args=(lhs, rhs))
    self.assertEqual(ret, torch.ones(n, n) * 2)
@dist_init
def test_add_with_id(self):
    """torch.add over rpc_sync, addressing the destination by WorkerInfo."""
    n = self.rank + 1
    dst_worker_info = rpc.get_worker_info("worker{}".format(n % self.world_size))
    ret = rpc.rpc_sync(
        dst_worker_info, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
    )
    self.assertEqual(ret, torch.ones(n, n) * 2)
@dist_init
def test_scalar_add(self):
    """torch.add of a tensor and a Python scalar over rpc_sync."""
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(n, n), n))
    self.assertEqual(ret, torch.ones(n, n) + n)
@dist_init
def test_async_add(self):
    """torch.add over rpc_async; the result is read from the returned future."""
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    fut = rpc.rpc_async(dst, torch.add, args=(torch.ones(n, n), torch.ones(n, n)))
    self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
@dist_init
def test_nonzero(self):
    """torch.nonzero evaluated remotely matches the local result."""
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    x = torch.ones(self.world_size, self.world_size)
    x[self.rank][self.rank] = 0
    ret = rpc.rpc_sync(dst, torch.nonzero, args=(x,))
    self.assertEqual(ret, x.nonzero())
@dist_init
def test_multi_rpc(self):
    """Twenty back-to-back rpc_sync calls with sizes rank + 1 .. rank + 20."""
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    for n in range(self.rank + 1, self.rank + 21):
        ret = rpc.rpc_sync(
            dst, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)
@dist_init(setup_rpc=False)
def test_shutdown(self):
    """RPC works after init and fails after shutdown; repeating shutdown is safe."""
    rpc.init_rpc(
        name="worker%d" % self.rank,
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )

    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)

    def call_add():
        return rpc.rpc_sync(
            dst, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
        )

    self.assertEqual(call_add(), torch.ones(n, n) * 2)
    rpc.shutdown()

    with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
        call_add()

    # A second shutdown() must be harmless.
    rpc.shutdown()
@dist_init
def test_expected_src(self):
    """The UDF on the next worker receives the sender's rank.

    Each rank sends its own rank to rank + 1, so this worker's VALUE_FUTURE
    ends up holding the previous rank.
    """
    dst_rank = (self.rank + 1) % self.world_size
    expected_src_rank = (self.rank - 1) % self.world_size
    # The RPC's return value (None from set_value) is not needed.
    rpc.rpc_sync("worker{}".format(dst_rank), set_value, args=(self.rank,))
    self.assertEqual(VALUE_FUTURE.result(), expected_src_rank)
@dist_init
def test_py_built_in(self):
    """The Python builtin ``min`` executed over RPC."""
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    operands = (n, n + 1, n + 2)
    ret = rpc.rpc_sync(dst, min, args=operands)
    self.assertEqual(ret, min(*operands))
@dist_init
def test_py_user_defined(self):
    """A user-defined function invoked over RPC with keyword arguments."""
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    call_kwargs = {"a": n, "b": n + 1, "c": n + 2}
    ret = rpc.rpc_sync(dst, my_function, kwargs=call_kwargs)
    self.assertEqual(ret, my_function(n, n + 1, n + 2))
@dist_init
def test_py_class_constructor(self):
    """Constructing a user-defined class remotely returns the instance by value."""
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    instance = rpc.rpc_sync(dst, MyClass, args=(n,))
    self.assertEqual(instance.a, n)
@dist_init
def test_py_class_instance_method(self):
    """A bound instance method of a user-defined class invoked over RPC."""
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    ret = rpc.rpc_sync(dst, MyClass(2).my_instance_method, args=(n,))
    expected = MyClass(2).my_instance_method(n)
    self.assertEqual(ret, expected)
@dist_init
def test_py_class_method(self):
    """A @classmethod of a user-defined class invoked over RPC."""
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    call_args = (n, n + 1)
    ret = rpc.rpc_sync(dst, MyClass.my_class_method, args=call_args)
    self.assertEqual(ret, MyClass.my_class_method(*call_args))
@dist_init
def test_py_class_static_method(self):
    """A @staticmethod of a user-defined class invoked over RPC."""
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    arg = n + 10
    ret = rpc.rpc_sync(dst, MyClass.my_static_method, args=(arg,))
    self.assertEqual(ret, MyClass.my_static_method(arg))
@dist_init
def test_py_multi_async_call(self):
    """Two rpc_async calls in flight at once to the same WorkerInfo."""
    n = self.rank + 1
    dst = rpc.get_worker_info("worker{}".format(n % self.world_size))
    static_fut = rpc.rpc_async(dst, MyClass.my_static_method, args=(n + 10,))
    min_fut = rpc.rpc_async(dst, min, args=(n, n + 1, n + 2))
    self.assertEqual(static_fut.wait(), MyClass.my_static_method(n + 10))
    self.assertEqual(min_fut.wait(), min(n, n + 1, n + 2))
@dist_init
def test_py_no_return_result(self):
    """A UDF with no return statement yields the same value (None) remotely and locally."""
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    ret = rpc.rpc_sync(dst, no_result)
    expected = no_result()
    self.assertEqual(ret, expected)
@dist_init
def test_py_tensors(self):
    """Tensors passed as positional arguments to a Python UDF over RPC."""
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    ret = rpc.rpc_sync(
        dst, my_tensor_function, args=(torch.ones(n, n), torch.ones(n, n))
    )
    expected = my_tensor_function(torch.ones(n, n), torch.ones(n, n))
    self.assertEqual(ret, expected)
@dist_init
def test_py_tensors_multi_async_call(self):
    """100 rpc_async calls in flight at once, with tensors of size 0x0 .. 99x99."""
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    futs = [
        rpc.rpc_async(
            dst, my_tensor_function, args=(torch.ones(i, i), torch.ones(i, i))
        )
        for i in range(100)
    ]
    # Each future's index is the size its request was issued with.
    for i, fut in enumerate(futs):
        self.assertEqual(
            fut.wait(), my_tensor_function(torch.ones(i, i), torch.ones(i, i))
        )
@dist_init
def test_py_tensors_in_container(self):
    """Tensors nested in a list, a namedtuple and a dict are passed over RPC."""
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    list_arg = [torch.ones(n, n), torch.ones(n, n)]
    tuple_arg = TensorClass(build_complex_tensors())
    dict_arg = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)}
    call_args = (list_arg, tuple_arg, dict_arg)
    ret = rpc.rpc_sync(dst, my_complex_tensor_function, args=call_args)
    self.assertEqual(ret, my_complex_tensor_function(*call_args))
@dist_init
def test_py_nested_pickle(self):
    """An argument whose pickled state itself uses the RPC pickler (MyPickleClass)."""
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    ret = rpc.rpc_sync(
        dst, run_nested_pickle, args=(MyPickleClass(), torch.ones(2, 2))
    )

    # Build the state the remote side should have reconstructed, without pickling.
    local = MyPickleClass()
    local.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2)))
    self.assertEqual(ret, run_nested_pickle(local, torch.ones(2, 2)))
@dist_init
def test_py_function_exception(self):
    """Calling a UDF with the wrong arity surfaces the remote TypeError."""
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    with self.assertRaisesRegex(Exception, "TypeError"):
        # Must raise; there is no result to keep.
        rpc.rpc_sync(dst, no_result, args=(10,))
@dist_init
def test_py_raise_in_user_func(self):
    """An exception raised inside a remote UDF is re-raised by fut.wait()."""
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    fut = rpc.rpc_async(dst, raise_func)
    with self.assertRaisesRegex(Exception, "ValueError"):
        fut.wait()
@dist_init
def test_nested_rpc(self):
    """While serving our request, the callee issues an RPC back to us."""
    me = "worker{}".format(self.rank)
    peer = "worker{}".format((self.rank + 1) % self.world_size)
    ret = rpc.rpc_sync(peer, nested_rpc, args=(me,))
    self.assertEqual(ret, torch.ones(2, 2) + 1)
def _stress_test_rpc(self, f, repeat=1000, args=()):
    """Issue ``repeat`` rpc_async calls of ``f`` and check each one returns 0.

    Prints the wall-clock time this rank spent issuing and waiting.
    """
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    start = time.time()
    futs = [rpc.rpc_async(dst, f, args=args) for _ in range(repeat)]
    for fut in futs:
        self.assertEqual(fut.wait(), 0)
    elapsed = time.time() - start
    print(
        "Rank {} finished testing {} {} times in {} seconds.".format(
            self.rank, f.__name__, repeat, elapsed
        )
    )
@dist_init
def test_stress_light_rpc(self):
    """Stress test with the helper's default repeat count and a no-op UDF."""
    self._stress_test_rpc(f=light_rpc)
@dist_init
def test_stress_heavy_rpc(self):
    """Stress test with 20 calls of a CPU-heavy UDF on a 100x100 tensor."""
    payload = torch.ones(100, 100)
    self._stress_test_rpc(heavy_rpc, repeat=20, args=(payload,))
@dist_init
def test_builtin_remote_ret(self):
    """rpc.remote of a builtin returns an RRef whose value to_here() fetches."""
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    rref = rpc.remote(dst, torch.add, args=(torch.ones(n, n), torch.ones(n, n)))
    self.assertEqual(rref.to_here(), torch.ones(n, n) * 2)
@dist_init
def test_asymmetric_load_with_join(self):
    """Test graceful termination.

    worker0 drives the whole test. It first loads only worker1, then only
    worker2. Every other rank does nothing in the test body.
    """
    if self.rank != 0:
        return
    # Raises AssertionError like the old bare assert, but is not stripped by -O.
    self.assertGreaterEqual(self.world_size, 3)
    num_repeat = 100
    # Phase 1 loads only worker1; phase 2 loads only worker2. If join is not
    # correctly implemented, worker2 would already be closed by phase 2.
    for dst in ("worker1", "worker2"):
        # A fresh list per phase, so phase 2 does not re-wait phase-1 futures.
        futs = [
            rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
            for _ in range(num_repeat)
        ]
        for fut in futs:
            self.assertEqual(fut.wait(), 0)
def _test_multi_remote_call(self, fn, args_fn=lambda x: (), kwargs_fn=lambda x: {}):
    """Create several RRefs with rpc.remote(fn, ...) and compare each to a local call.

    Args:
        fn: Callable executed both remotely and locally.
        args_fn: Maps a size ``n`` to the positional arguments for ``fn``.
        kwargs_fn: Maps a size ``n`` to the keyword arguments for ``fn``.
    """
    m = 10
    base = self.rank + 1
    dst = "worker{}".format(base % self.world_size)
    rrefs = []
    expected = []
    for i in range(m):
        # Size base + i per call. The previous ``n = n + i`` accumulated across
        # iterations (sizes 1, 2, 4, 7, ... for rank 0), which looked
        # unintended.
        n = base + i
        rrefs.append(rpc.remote(dst, fn, args=args_fn(n), kwargs=kwargs_fn(n)))
        expected.append(fn(*args_fn(n), **kwargs_fn(n)))
    for rref, value in zip(rrefs, expected):
        self.assertEqual(rref.to_here(), value)
@dist_init
def test_multi_builtin_remote_ret(self):
    """Several rpc.remote calls of the builtin torch.add."""
    self._test_multi_remote_call(
        torch.add, args_fn=lambda n: (torch.ones(n, n), torch.ones(n, n))
    )
@dist_init
def test_py_udf_remote(self):
    """rpc.remote of a Python UDF given keyword arguments."""
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    rref = rpc.remote(dst, my_function, kwargs={"a": n, "b": n + 1, "c": n + 2})
    self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2))
@dist_init
def test_multi_py_udf_remote(self):
    """Several rpc.remote calls of a Python UDF driven by keyword arguments."""
    self._test_multi_remote_call(
        my_function,
        kwargs_fn=lambda n: {key: torch.ones(n, n) for key in ("a", "b", "c")},
    )
@dist_init
def test_py_rref_args(self):
    """RRefs owned by the destination are passed back to it as UDF arguments."""
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    rref_a, rref_b = [
        rpc.remote(dst, torch.add, args=(torch.ones(n, n), offset))
        for offset in (2, 1)
    ]
    rref_c = rpc.remote(dst, my_rref_function, args=(rref_a, rref_b))
    self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
@dist_init
def test_py_rref_args_user_share(self):
    """RRefs owned by one worker are shared as UDF arguments with another worker."""
    n = self.rank + 1
    owner = "worker{}".format(n % self.world_size)
    user = "worker{}".format((n + 1) % self.world_size)
    rref_a = rpc.remote(owner, my_function, args=(torch.ones(n, n), 2, 0))
    rref_b = rpc.remote(owner, my_function, args=(torch.ones(n, n), 1, 0))
    rref_c = rpc.remote(user, my_rref_function, args=(rref_a, rref_b))
    self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
@dist_init
def test_py_rpc_rref_args(self):
    """RRefs passed as arguments to a synchronous RPC on their owner."""
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    rref_a = rpc.remote(dst, my_function, args=(torch.ones(n, n), 2, 0))
    rref_b = rpc.remote(dst, my_function, args=(torch.ones(n, n), 1, 0))
    total = rpc.rpc_sync(dst, my_rref_function, args=(rref_a, rref_b))
    self.assertEqual(total, torch.ones(n, n) + 4)
@dist_init
def test_nested_remote(self):
    """A remote UDF that itself calls rpc.remote on a second worker."""
    n = self.rank + 1
    first_hop = "worker{}".format(n % self.world_size)
    second_hop = "worker{}".format((n + 1) % self.world_size)
    rref = rpc.remote(first_hop, nested_remote, args=(second_hop,))
    self.assertEqual(rref.to_here(), torch.ones(2, 2) + 3)
@dist_init
def test_nested_rref(self):
    """An RRef whose value is a tuple of RRefs owned by a second worker."""
    n = self.rank + 1
    first_hop = "worker{}".format(n % self.world_size)
    second_hop = "worker{}".format((n + 1) % self.world_size)
    rref_of_rrefs = rpc.remote(first_hop, nested_rref, args=(second_hop,))
    inner = rref_of_rrefs.to_here()
    self.assertEqual(len(inner), 2)
    for offset, rref in zip((1, 2), inner):
        self.assertEqual(rref.to_here(), torch.ones(2, 2) + offset)
@dist_init
def test_nested_rref_stress(self):
    """Create 20 nested-RRef results without waiting, then verify each one."""
    n = self.rank + 1
    first_hop = "worker{}".format(n % self.world_size)
    second_hop = "worker{}".format((n + 1) % self.world_size)
    all_rrefs = [
        rpc.remote(first_hop, nested_rref, args=(second_hop,)) for _ in range(20)
    ]
    for rref_of_rrefs in all_rrefs:
        rrefs = rref_of_rrefs.to_here()
        self.assertEqual(len(rrefs), 2)
        self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
        self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)
@dist_init
def test_multi_layer_nested_async_rpc(self):
    """Start a 20-hop chain of async RPCs and return without waiting on it.

    The termination algorithm must still detect the messages in flight.
    Otherwise some peer could exit early and leave the others with timeout or
    connection-closed errors.
    """
    ttl = 20
    dst_rank = (self.rank + 1) % self.world_size
    multi_layer_nested_async_rpc(dst_rank, self.world_size, ttl)
@dist_init
def test_remote_with_exception(self):
    """Errors raised while computing an RRef's value surface from to_here()."""
    peer = "worker{}".format((self.rank + 1) % self.world_size)
    me = "worker{}".format(self.rank)

    # RRef owned by another worker: the UDF raises ValueError.
    rref = rpc.remote(peer, raise_func)
    with self.assertRaisesRegex(Exception, "ValueError"):
        rref.to_here()

    # RRef owned by this worker: calling with the wrong arity raises TypeError.
    rref = rpc.remote(me, no_result, args=(10,))
    with self.assertRaisesRegex(Exception, "TypeError"):
        rref.to_here()
@dist_init
def test_rpc_return_rref(self):
    """An RRef created inside an RPC is returned to the caller and is usable there."""
    n = self.rank + 1
    callee = "worker{}".format(n % self.world_size)
    owner = "worker{}".format((n + 1) % self.world_size)
    rref = rpc.rpc_sync(callee, rpc_return_rref, args=(owner,))
    self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)
@dist_init
def test_rref_forward_chain(self):
    """Forward an RRef through ``ttl`` nested remote hops and unwrap each hop."""
    ttl = 8
    n = self.rank + 1
    dst_rank = n % self.world_size
    rref = rpc.remote(
        "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1)
    )

    ret = rref_forward_chain(dst_rank, self.world_size, rref, ttl)
    for _ in range(ttl):
        # Each hop wraps the next one in a single-element list.
        self.assertEqual(len(ret), 1)
        ret = ret[0].to_here()
    self.assertEqual(ret, torch.add(torch.ones(n, n), 1))
@dist_init
def test_local_rref_no_fork(self):
    """A locally created RRef exposes its value through local_value()."""
    value = 35
    local_rref = RRef(value)
    self.assertEqual(local_rref.local_value(), value)
@dist_init
def test_return_local_rrefs(self):
    """RRefs created on the callee are returned, then mutated and read via their owner."""
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    rref_list = rpc.rpc_sync(dst, get_rref_list, args=([1, 2, 3],))

    for rref in rref_list:
        rpc.rpc_sync(
            rref.owner(),
            _call_method_on_rref,
            args=(MyClass.increment_value, rref, 10),
        )

    rets = []
    for rref in rref_list:
        rets.append(
            rpc.rpc_sync(
                rref.owner(), _call_method_on_rref, args=(MyClass.get_value, rref)
            )
        )
    self.assertEqual(rets, [11, 12, 13])
@dist_init
def test_owner_equality(self):
    """WorkerInfo values returned by RRef.owner() compare and hash by worker."""
    a = RRef(40)
    b = RRef(50)

    other_rank = (self.rank + 1) % self.world_size
    other_a = rpc.remote(
        "worker{}".format(other_rank), torch.add, args=(torch.ones(1), 1)
    )
    other_b = rpc.remote(
        "worker{}".format(other_rank), torch.add, args=(torch.ones(1), 1)
    )
    other_a.to_here()  # to ensure clean termination
    other_b.to_here()

    # Comparing with a value that is not a WorkerInfo must report inequality.
    self.assertNotEqual(a.owner(), 23)
    self.assertEqual(other_a.owner(), other_b.owner())
    self.assertNotEqual(a.owner(), other_a.owner())
    self.assertEqual(other_a.owner(), other_a.owner())
    self.assertEqual(a.owner(), a.owner())
    self.assertEqual(a.owner(), b.owner())
    self.assertEqual(a.owner(), rpc.get_worker_info())

    # Owners also work as dict keys: equal owners map to one entry.
    x = {}
    x[a.owner()] = a
    x[other_a.owner()] = other_a
    self.assertEqual(x[a.owner()], a)
    self.assertEqual(x[b.owner()], a)
    self.assertEqual(x[other_a.owner()], other_a)
    self.assertEqual(x[other_b.owner()], other_a)
    self.assertEqual(len(x), 2)
@dist_init
def test_pass_local_rrefs(self):
    """A locally owned RRef passed to rpc_sync, rpc_async and rpc.remote."""
    dst_worker = "worker{}".format((self.rank + 1) % self.world_size)
    rref = RRef(40)

    sync_value = rpc.rpc_sync(dst_worker, add_rref_to_value, args=(rref, 50))
    self.assertEqual(sync_value, 90)

    async_value = rpc.rpc_async(
        dst_worker, add_rref_to_value, args=(rref, 50)
    ).wait()
    self.assertEqual(async_value, 90)

    remote_value = rpc.remote(
        dst_worker, add_rref_to_value, args=(rref, 50)
    ).to_here()
    self.assertEqual(remote_value, 90)
@dist_init
def test_remote_same_worker(self):
    """Both argument RRefs and the UDF that consumes them live on the same worker."""
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    rref_a = rpc.remote(dst, torch.add, args=(torch.ones(n, n), 2))
    rref_b = rpc.remote(dst, torch.add, args=(torch.ones(n, n), 1))
    rref_c = rpc.remote(dst, my_rref_function, args=(rref_a, rref_b))
    self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
@dist_init(setup_rpc=True)
def test_call_method_on_rref(self):
    """
    Tests that an instance method can be called on a remote object by using
    rref.owner() as the destination of the call.
    """
    vals = [10, 2, 5, 7]
    initial, sync_inc, async_inc, remote_inc = vals
    dst_worker = "worker{}".format((self.rank + 1) % self.world_size)

    # Create the remote object with the first value.
    rref = rpc.remote(dst_worker, MyClass, args=(initial,))

    # Increment it once through each RPC flavour, addressed via rref.owner().
    rpc.rpc_sync(
        rref.owner(),
        _call_method_on_rref,
        args=(MyClass.increment_value, rref, sync_inc),
    )
    rpc.rpc_async(
        rref.owner(),
        _call_method_on_rref,
        args=(MyClass.increment_value, rref, async_inc),
    ).wait()
    rpc.remote(
        rref.owner(),
        _call_method_on_rref,
        args=(MyClass.increment_value, rref, remote_inc),
    ).to_here()

    # Read the final state back.
    result = rpc.rpc_sync(
        dst_worker, _call_method_on_rref, args=(MyClass.get_value, rref)
    )
    self.assertEqual(result, sum(vals))
def _test_rref_leak(self, ignore_leak=False):
    """Init RPC, create an RRef to a remote value, and shut down while it is alive.

    Args:
        ignore_leak: When True, sets
            ``torch.distributed.rpc.api._ignore_rref_leak`` before shutdown.
            test_ignore_rref_leak expects shutdown to succeed in that case;
            test_rref_leak expects a "Leaking RRef" RuntimeError otherwise.
    """
    rpc.init_rpc(
        name="worker{}".format(self.rank),
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )
    # This is for the below `dist.barrier`.
    # For `RpcAgent` other than `ProcessGroupAgent`,
    # no `_default_pg` is initialized.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )
    # Wait for all init to complete.
    dist.barrier()
    # ``rref`` is never read but must stay bound until rpc.shutdown() runs:
    # it is the live RRef that the leak check is expected to detect.
    rref = rpc.remote(
        "worker{}".format((self.rank + 1) % self.world_size),
        torch.add,
        args=(torch.ones(2, 2), 1)
    )
    if ignore_leak:
        import torch.distributed.rpc.api as api
        # NOTE(review): this module-level flag is never reset here; presumably
        # each test runs in fresh processes — confirm.
        api._ignore_rref_leak = True
    rpc.shutdown()
@dist_init(setup_rpc=False)
def test_rref_leak(self):
    """Shutting down while a remote RRef is still alive must report the leak."""
    with self.assertRaisesRegex(RuntimeError, "Leaking RRef"):
        self._test_rref_leak(ignore_leak=False)
@dist_init(setup_rpc=False)
def test_ignore_rref_leak(self):
    """With the leak check disabled, shutdown succeeds despite a live RRef."""
    skip_leak_check = True
    self._test_rref_leak(ignore_leak=skip_leak_check)
@dist_init
def test_rref_str(self):
    """str() of owner and user RRefs shows their GloballyUniqueId values.

    The expected local ids (0 for the owner RRef, then 1 and 2 for the user
    RRef) assume these are the first RRefs this worker creates in the test.
    Keep the creation order below unchanged.
    """
    rref1 = RRef(self.rank)
    id_class = "GloballyUniqueId"
    self.assertEqual(
        "OwnerRRef({}({}, 0))".format(id_class, self.rank),
        str(rref1)
    )
    dst_rank = (self.rank + 1) % self.world_size
    rref2 = rpc.remote("worker{}".format(dst_rank), torch.add, args=(torch.ones(2, 2), 1))
    self.assertEqual(
        str(rref2),
        "UserRRef(RRefId = {0}({1}, 1), ForkId = {0}({1}, 2))".format(id_class, self.rank)
    )
@dist_init
def test_rref_context_debug_info(self):
    """Check ``num_owner_rrefs`` from the RRef context debug info in three stages.

    Check 1: a purely local RRef is not counted (expected 0).
    Check 2: after the RRef is shared as an RPC argument, the expected count is 1.
    Check 3: after the rpc.remote calls, the expected count is 2.
    """
    # This test checks local states that are modified by remote workers, so
    # a barrier is needed before and after every check. The barrier before a
    # check makes sure all previous state changes have been applied globally.
    # The barrier after it ensures that no later state change leaks into the
    # current check.
    # dist.barrier needs a default process group; create a gloo one if the
    # RPC agent did not.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )
    from torch.distributed.rpc import _rref_context_get_debug_info
    # Check 1: local RRef does not update owners_ map
    #################################################
    rref1 = RRef(self.rank)
    # don't need a barrier here as local RRef is handled by this thread
    info = _rref_context_get_debug_info()
    self.assertIn("num_owner_rrefs", info)
    # RRef on local value is not added to context until shared across RPC
    self.assertEqual(0, int(info["num_owner_rrefs"]))
    # barrier after check 1
    dist.barrier()
    # Check 2: Sharing RRef as an arg should update owners_ map
    ###########################################################
    dst_rank = (self.rank + 1) % self.world_size
    # The next worker stores the shared RRef in its module-level global_rref.
    rpc.rpc_sync(
        "worker{}".format(dst_rank),
        set_global_rref,
        args=(rref1,)
    )
    # barrier before check 2
    dist.barrier()
    info = _rref_context_get_debug_info()
    self.assertIn("num_owner_rrefs", info)
    self.assertEqual(1, int(info["num_owner_rrefs"]))
    # barrier after check 2
    dist.barrier()
    # clear states for check 2
    rpc.rpc_sync("worker{}".format(dst_rank), clear_global_rref)
    # Check 3: rpc.remote call should update owners_ map
    ####################################################
    rref2 = rpc.remote(
        "worker{}".format(dst_rank),
        torch.add,
        args=(torch.ones(2, 2), 1)
    )
    rref3 = rpc.remote(
        "worker{}".format(dst_rank),
        torch.add,
        args=(torch.ones(2, 2), 1)
    )
    rref2.to_here()
    rref3.to_here()
    # barrier before check 3
    dist.barrier()
    info = _rref_context_get_debug_info()
    self.assertIn("num_owner_rrefs", info)
    # NOTE(review): the expected 2 presumably comes from the two rpc.remote
    # results that the previous rank creates on this worker (every rank
    # targets rank + 1) — confirm.
    self.assertEqual(2, int(info["num_owner_rrefs"]))
    # barrier after check 3
    dist.barrier()
@unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/31112")
@dist_init
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
def test_process_group_debug_info(self):
    """Check the ProcessGroupAgent debug counters around one in-flight RPC.

    The counters are sampled while idle, while a request sent to the next
    rank is parked on ``DONE_FUTURE``, and again after that request finished.
    """
    from torch.distributed.rpc.api import _agent

    pool_size = self.rpc_backend_options.num_send_recv_threads

    def checked_debug_info():
        # Fetch the agent's debug dict and require all three counters in it.
        snapshot = _agent.get_debug_info()
        for key in ("num_pending_requests", "thread_pool_size", "num_idle_threads"):
            self.assertIn(key, snapshot)
        return snapshot

    # Phase 1: nothing in flight, every pool thread idle.
    info = checked_debug_info()
    self.assertEqual(int(info["num_pending_requests"]), 0)
    self.assertEqual(int(info["thread_pool_size"]), pool_size)
    self.assertEqual(int(info["num_idle_threads"]), pool_size)

    # Phase 2: send the next rank a request that blocks on its DONE_FUTURE.
    dst_rank = (self.rank + 1) % self.world_size
    fut = rpc.rpc_async(
        "worker{}".format(dst_rank), set_and_check_done, args=(dst_rank,)
    )
    # Blocks until the previous rank's request has arrived on this rank.
    self.assertEqual(self.rank, VALUE_FUTURE.result())
    info = checked_debug_info()
    self.assertEqual(int(info["num_pending_requests"]), 1)
    self.assertEqual(int(info["thread_pool_size"]), pool_size)
    # We cannot know whether the send thread has returned yet, so either
    # one or two threads may still be busy.
    self.assertIn(int(info["num_idle_threads"]), [pool_size - 1, pool_size - 2])

    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )
    # Keep every request open until all ranks have sampled
    # num_pending_requests, then release the request served on this rank.
    dist.barrier()
    DONE_FUTURE.set_result(self.rank)
    self.assertEqual(dst_rank, fut.wait())
    # Make sure dst_rank has finished processing the request.
    dist.barrier()

    # Phase 3: the request is done, so nothing should be pending.
    info = checked_debug_info()
    self.assertEqual(int(info["num_pending_requests"]), 0)
    self.assertEqual(int(info["thread_pool_size"]), pool_size)
    # A completed future does not guarantee the local send/recv threads are
    # back in the pool yet, so poll up to three times. (NB: this might still
    # be flaky; if it is, the assert has to be relaxed.)
    for _ in range(3):
        info = _agent.get_debug_info()
        if int(info["num_idle_threads"]) == pool_size:
            break
        time.sleep(0.1)
    self.assertEqual(int(info["num_idle_threads"]), pool_size)
    # Final barrier so no rank sends its SHUTDOWN message early.
    dist.barrier()
@dist_init(setup_rpc=False)
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
def test_local_shutdown(self):
    """Start RPC and immediately shut it down locally, sending no messages.

    ``graceful=False`` is passed so that shutdown does not wait for the
    other workers.
    """
    worker_name = "worker%d" % self.rank
    backend_type = rpc.backend_registry.BackendType[
        dist_utils.TEST_CONFIG.rpc_backend_name
    ]
    rpc.init_rpc(
        name=worker_name,
        backend=backend_type,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )
    rpc.shutdown(graceful=False)
@dist_init
def test_debug_info(self):
    """Check that rpc._get_debug_info() is the union of the per-module dicts.

    Only keys are validated here; the values are covered by each module's
    own debug-info tests.
    """
    from torch.distributed.rpc import (
        _get_debug_info,
        _rref_context_get_debug_info
    )
    from torch.distributed.rpc.api import _agent
    import torch.distributed.autograd as dist_autograd

    info = _get_debug_info()
    rref_info = _rref_context_get_debug_info()
    agent_info = _agent.get_debug_info()
    autograd_info = dist_autograd._get_debug_info()

    # The module dicts must be pairwise disjoint. Intersecting all three at
    # once only catches keys shared by every module; a key shared by just
    # two of them would be silently overwritten when the dicts are merged.
    self.assertEqual(set(), rref_info.keys() & agent_info.keys())
    self.assertEqual(set(), rref_info.keys() & autograd_info.keys())
    self.assertEqual(set(), agent_info.keys() & autograd_info.keys())

    expected = {**rref_info, **agent_info, **autograd_info}
    self.assertEqual(expected, info)
@dist_init(setup_rpc=False)
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
def test_local_shutdown_with_rpc(self):
    """Start RPC, send one RPC to the next rank, then shut down locally."""
    rpc.init_rpc(
        name="worker%d" % self.rank,
        backend=rpc.backend_registry.BackendType[
            dist_utils.TEST_CONFIG.rpc_backend_name
        ],
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )

    # Every rank asks its successor to add two (size x size) tensors of ones.
    size = self.rank + 1
    peer_name = "worker{}".format(size % self.world_size)
    lhs = torch.ones(size, size)
    rhs = torch.ones(size, size)
    rpc.rpc_sync(peer_name, torch.add, args=(lhs, rhs))

    # Wait for every rank before shutting down. Without this barrier some
    # RPCs can time out because their receiver has already terminated.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )
    dist.barrier()

    # graceful=False: do not wait for the other workers.
    rpc.shutdown(graceful=False)
@dist_init(setup_rpc=False)
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
def test_wait_all_workers_and_shutdown(self):
    """Call rpc._wait_all_workers() followed by rpc.shutdown().

    This test ensures both can be called without errors being raised due to
    attempting to shut down multiple times.
    """
    backend_name = dist_utils.TEST_CONFIG.rpc_backend_name
    rpc.init_rpc(
        name="worker%d" % self.rank,
        backend=rpc.backend_registry.BackendType[backend_name],
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options
    )

    from torch.distributed.rpc.api import _wait_all_workers

    # Intentionally exercise the internal helper before the public shutdown.
    _wait_all_workers()
    rpc.shutdown()
@dist_init(setup_rpc=False)
def test_get_rpc_timeout(self):
    """Check that rpc.get_rpc_timeout() reports the timeout set via backend options."""
    expected_timeout = timedelta(seconds=1)
    # Every access to ``self.rpc_backend_options`` builds a new
    # RpcBackendOptions, so keep one reference and configure that instance.
    options = self.rpc_backend_options
    options.rpc_timeout = expected_timeout

    rpc.init_rpc(
        name="worker{}".format(self.rank),
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=options,
    )
    self.assertEqual(expected_timeout, rpc.get_rpc_timeout())
    rpc.shutdown()
@dist_init
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
def test_rpc_timeouts(self):
    """Check that the agent-wide RPC timeout applies to RPCs issued after it is set.

    The default timeout is restored in a ``finally`` block. If an assertion
    fails part-way through, shutdown messages are then not sent under the
    1 ms test timeout, which would otherwise mask the original failure.
    """
    dst_worker = "worker{}".format((self.rank + 1) % self.world_size)
    try:
        rpc._set_rpc_timeout(timedelta(milliseconds=1))
        # futures should time out and be marked with an exception indicating it as such.
        futs = [rpc.rpc_async(dst_worker, my_sleep_func, args=()) for _ in range(10)]
        for fut in futs:
            with self.assertRaisesRegex(RuntimeError, "RPC ran for more than"):
                fut.wait()

        # ensure that if a new timeout is set old futures don't time out but new ones do.
        rpc._set_rpc_timeout(timedelta(seconds=200))
        # create a longstanding RPC.
        fut1 = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,))
        # now, set a short timeout.
        rpc._set_rpc_timeout(timedelta(milliseconds=1))
        # fut2 is issued under the short timeout and should time out;
        # fut1 was issued before the change and should not.
        fut2 = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,))
        with self.assertRaises(RuntimeError):
            fut2.wait()
        fut1.wait()

        # future should run to completion if the timeout is zero.
        rpc._set_rpc_timeout(timedelta(seconds=0))
        rpc.rpc_async(dst_worker, my_sleep_func, args=()).wait()
    finally:
        # reset to default timeout so shutdown messages can process cleanly,
        # even when an assertion above has failed.
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT)
def test_requires_process_group_agent_decorator(self):
    """Check both paths of the requires_process_group_agent decorator.

    On the PROCESS_GROUP backend the wrapped function runs normally. On any
    other backend, ``unittest.skipUnless`` replaces it with a wrapper that
    raises ``unittest.SkipTest`` carrying the decorator's message. Previously
    that path went unchecked.
    """
    @requires_process_group_agent("test_func did not run")
    def test_func():
        return "expected result"

    if dist_utils.TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP":
        self.assertEqual(test_func(), "expected result")
    else:
        with self.assertRaisesRegex(unittest.SkipTest, "test_func did not run"):
            test_func()
def test_dist_init_decorator(self):
    """Check that both forms of @dist_init pass the wrapped return value through.

    The forms are the parameterized ``dist_init(setup_rpc=False)`` and the
    bare ``dist_init``.
    """
    for decorator in (dist_init(setup_rpc=False), dist_init):
        @decorator
        def test_func(self):
            return "expected result"

        self.assertEqual(test_func(self), "expected result")
def test_use_rpc_pickler(self):
    """Check that _use_rpc_pickler installs a pickler only for its ``with`` block."""
    class TestPickler:
        pass

    custom_pickler = TestPickler()
    with _use_rpc_pickler(custom_pickler):
        self.assertIs(torch.distributed.rpc.api._default_pickler, custom_pickler)
    # After the context exits, the internal RPC pickler must be the default again.
    self.assertIs(torch.distributed.rpc.api._default_pickler, _internal_rpc_pickler)