blob: 4a65776d67725e4e42cae9a05623dbb31dc35048 [file] [log] [blame]
#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function, unicode_literals
import sys
import unittest
import torch
import torch.distributed as dist
from common_distributed import MultiProcessTestCase
from common_utils import load_tests, run_tests
# it is used to test python user defined function over rpc
def my_function(a, b, c):
return a + b + c
# it is used to test python user defined function over rpc
def no_result():
print("do nothing")
def nested_rpc(dst):
return dist.rpc(dst, torch.add, args=(torch.ones(2, 2), 1))
def light_rpc():
return 0
def heavy_rpc(tensor):
for i in range(1, 100):
tensor *= i
tensor /= i + 1
return 0
# it is used to test python user defined function over rpc
def raise_func():
raise ValueError("Expected error")
# it is used to test python user defined class and methods over rpc
class my_class:
def __init__(self, a):
self.a = a
def my_instance_method(self, b):
return self.a + b
@classmethod
def my_class_method(cls, d, e):
return d + e
@staticmethod
def my_static_method(f):
return f > 10
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
if not dist.is_available():
print("c10d not available, skipping tests")
sys.exit(0)
def _wrap_with_rpc(func):
'''
We use this decorator for setting up and tearing down state since
MultiProcessTestCase runs each `test*` method in a separate process and
each process just runs the `test*` method without actually calling
'setUp' and 'tearDown' methods of unittest.
'''
def wrapper(self):
store = dist.FileStore(self.file.name, self.world_size)
dist.init_process_group(backend='gloo', rank=self.rank,
world_size=self.world_size, store=store)
dist.init_model_parallel('worker%d' % self.rank)
func(self)
dist.join_rpc()
return wrapper
@unittest.skipIf(
sys.version_info < (3, 0),
"Pytorch distributed rpc package " "does not support python2",
)
class RpcTest(MultiProcessTestCase):
@property
def world_size(self):
return 4
@_wrap_with_rpc
def test_worker_id(self):
n = self.rank + 1
peer_rank = n % self.world_size
self_worker_id = dist.get_worker_id()
peer_worker_id = dist.get_worker_id('worker{}'.format(peer_rank))
self.assertEqual(self_worker_id.name, 'worker{}'.format(self.rank))
self.assertEqual(peer_worker_id.name, 'worker{}'.format(peer_rank))
with self.assertRaisesRegex(RuntimeError, "Unknown destination worker"):
unknown_worker_id = dist.get_worker_id("WorkerUnknown")
@_wrap_with_rpc
def test_self_add(self):
self_worker_id = dist.get_worker_id()
self_worker_name = 'worker{}'.format(self.rank)
with self.assertRaisesRegex(
RuntimeError, "does not support making RPC calls to self"
):
dist.rpc(self_worker_id, torch.add, args=(torch.ones(2, 2), 1))
with self.assertRaisesRegex(
RuntimeError, "does not support making RPC calls to self"
):
dist.rpc(self_worker_name, torch.add, args=(torch.ones(2, 2), 1))
def test_duplicated_names(self):
store = dist.FileStore(self.file.name, self.world_size)
dist.init_process_group(backend="gloo", rank=self.rank,
world_size=self.world_size, store=store)
with self.assertRaisesRegex(RuntimeError, "is not unique"):
dist.init_model_parallel("duplicated_name")
dist.join_rpc()
def test_invalid_names(self):
store = dist.FileStore(self.file.name, self.world_size)
dist.init_process_group(backend="gloo", rank=self.rank,
world_size=self.world_size, store=store)
with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
dist.init_model_parallel("abc*")
with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
dist.init_model_parallel(" ")
with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
dist.init_model_parallel("")
# If the number in the message does not match, it is likely that the
# value of MAX_NAME_LEN in RPC WorkerId has changed.
with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
dist.init_model_parallel("".join(["a" for _ in range(500)]))
dist.join_rpc()
@_wrap_with_rpc
def test_add(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = dist.rpc(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), torch.ones(n, n)),
)
self.assertEqual(ret, torch.ones(n, n) * 2)
@_wrap_with_rpc
def test_add_with_id(self):
n = self.rank + 1
dst_rank = n % self.world_size
workder_id = dist.get_worker_id('worker{}'.format(dst_rank))
ret = dist.rpc(workder_id, torch.add,
args=(torch.ones(n, n), torch.ones(n, n)))
self.assertEqual(ret, torch.ones(n, n) * 2)
@_wrap_with_rpc
def test_scalar_add(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = dist.rpc(
"worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), n)
)
self.assertEqual(ret, (torch.ones(n, n) + n))
@_wrap_with_rpc
def test_async_add(self):
n = self.rank + 1
dst_rank = n % self.world_size
fut = dist.rpc(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), torch.ones(n, n)),
async_call=True,
)
self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
@_wrap_with_rpc
def test_nonzero(self):
n = self.rank + 1
dst_rank = n % self.world_size
x = torch.ones(self.world_size, self.world_size)
x[self.rank][self.rank] = 0
ret = dist.rpc("worker{}".format(dst_rank), torch.nonzero, args=(x,))
self.assertEqual(ret, x.nonzero())
@_wrap_with_rpc
def test_multi_rpc(self):
dst_rank = (self.rank + 1) % self.world_size
for i in range(20):
n = i + self.rank + 1
ret = dist.rpc(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), torch.ones(n, n)),
)
self.assertEqual(ret, torch.ones(n, n) * 2)
@_wrap_with_rpc
def test_sync_rpc(self):
dst_rank = (self.rank + 1) % self.world_size
for i in range(20):
dist.sync_rpc()
n = i + self.rank + 1
ret1 = dist.rpc(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), torch.ones(n, n)),
)
dist.sync_rpc()
ret2 = dist.rpc(
"worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2)
)
dist.sync_rpc()
self.assertEqual(ret1, torch.ones(n, n) * 2)
self.assertEqual(ret2, torch.ones(n, n) * 3)
@_wrap_with_rpc
def test_join_rpc(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = dist.rpc(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), torch.ones(n, n)),
)
self.assertEqual(ret, torch.ones(n, n) * 2)
dist.join_rpc()
with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
dist.rpc(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), torch.ones(n, n)),
)
# it's safe to call join_rpc() multiple times
dist.join_rpc()
@_wrap_with_rpc
def test_py_built_in(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = dist.rpc("worker{}".format(dst_rank), min, args=(n, n + 1, n + 2))
self.assertEqual(ret, min(n, n + 1, n + 2))
@_wrap_with_rpc
def test_py_user_defined(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = dist.rpc(
"worker{}".format(dst_rank),
my_function,
kwargs={"a": n, "b": n + 1, "c": n + 2},
)
self.assertEqual(ret, my_function(n, n + 1, n + 2))
@_wrap_with_rpc
def test_py_class_constructor(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = dist.rpc("worker{}".format(dst_rank), my_class, args=(n,))
self.assertEqual(ret.a, n)
@_wrap_with_rpc
def test_py_class_instance_method(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = dist.rpc(
"worker{}".format(dst_rank), my_class(2).my_instance_method, args=(n,)
)
self.assertEqual(ret, my_class(2).my_instance_method(n))
@_wrap_with_rpc
def test_py_class_method(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = dist.rpc(
"worker{}".format(dst_rank), my_class.my_class_method, args=(n, n + 1)
)
self.assertEqual(ret, my_class.my_class_method(n, n + 1))
@_wrap_with_rpc
def test_py_class_static_method(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = dist.rpc(
"worker{}".format(dst_rank), my_class.my_static_method, args=(n + 10,)
)
self.assertEqual(ret, my_class.my_static_method(n + 10))
@_wrap_with_rpc
def test_py_multi_async_call(self):
n = self.rank + 1
dst_rank = n % self.world_size
dst_worker_id = dist.get_worker_id('worker{}'.format(dst_rank))
fut1 = dist.rpc(dst_worker_id,
my_class.my_static_method,
args=(n + 10,),
async_call=True)
fut2 = dist.rpc(dst_worker_id,
min,
args=(n, n + 1, n + 2),
async_call=True)
self.assertEqual(fut1.wait(), my_class.my_static_method(n + 10))
self.assertEqual(fut2.wait(), min(n, n + 1, n + 2))
@_wrap_with_rpc
def test_py_no_return_result(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = dist.rpc("worker{}".format(dst_rank), no_result)
self.assertEqual(ret, no_result())
@_wrap_with_rpc
def test_py_function_exception(self):
n = self.rank + 1
dst_rank = n % self.world_size
with self.assertRaisesRegex(Exception, "TypeError"):
ret = dist.rpc("worker{}".format(dst_rank), no_result, args=(10,))
@_wrap_with_rpc
def test_py_raise_in_user_func(self):
n = self.rank + 1
dst_rank = n % self.world_size
fut = dist.rpc("worker{}".format(dst_rank), raise_func, async_call=True)
with self.assertRaisesRegex(Exception, "ValueError"):
fut.wait()
@_wrap_with_rpc
def test_nested_rpc(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = dist.rpc(
"worker{}".format(dst_rank),
nested_rpc,
args=("worker{}".format(self.rank),),
)
self.assertEqual(ret, torch.ones(2, 2) + 1)
def _stress_test_rpc(self, f, repeat=1000, args=()):
import time
n = self.rank + 1
dst_rank = n % self.world_size
futs = []
tik = time.time()
for _ in range(repeat):
fut = dist.rpc("worker{}".format(dst_rank), f, args=args, async_call=True)
futs.append(fut)
for fut in futs:
self.assertEqual(fut.wait(), 0)
tok = time.time()
print(
"Rank {} finished testing {} {} times in {} seconds.".format(
self.rank, f.__name__, repeat, tok - tik
)
)
@_wrap_with_rpc
def test_stress_light_rpc(self):
self._stress_test_rpc(light_rpc)
@_wrap_with_rpc
def test_stress_heavy_rpc(self):
self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),))
@_wrap_with_rpc
def test_builtin_remote_ret(self):
n = self.rank + 1
dst_rank = n % self.world_size
rref = dist.remote('worker{}'.format(dst_rank), torch.add,
args=(torch.ones(n, n), torch.ones(n, n)))
self.assertEqual(rref.to_here(), torch.ones(n, n) * 2)
@_wrap_with_rpc
def test_multi_builtin_remote_ret(self):
m = 10
n = self.rank + 1
dst_rank = n % self.world_size
rrefs = []
expected = []
for i in range(m):
n = n + i
rrefs.append(dist.remote(
'worker{}'.format(dst_rank),
torch.add,
args=(torch.ones(n, n), torch.ones(n, n))
))
expected.append(torch.ones(n, n) * 2)
for i in range(m):
self.assertEqual(rrefs[i].to_here(), expected[i])
if __name__ == '__main__':
run_tests()