from __future__ import absolute_import, division, print_function, unicode_literals

import sys
from functools import wraps
from os import getenv

import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.distributed.rpc.api import RpcBackend

if not dist.is_available():
    print("c10d not available, skipping tests")
    sys.exit(0)


class TestConfig:
    __slots__ = ['rpc_backend']

    def __init__(self, *args, **kwargs):
        assert len(args) == 0, "TestConfig only takes kwargs."
        for k, v in kwargs.items():
            setattr(self, k, v)


TEST_CONFIG = TestConfig(rpc_backend=getenv("RPC_BACKEND", RpcBackend.PROCESS_GROUP))

INIT_METHOD_TEMPLATE = "file://{file_name}?rank={rank}&world_size={world_size}"
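# Illustrative only (hypothetical values): formatting the template above with
# file_name="/tmp/sharedfile", rank=0, world_size=2 yields
# "file:///tmp/sharedfile?rank=0&world_size=2", i.e. a file:// rendezvous URL
# shared by all test processes.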


def dist_init(test_method):
    """
    We use this decorator for setting up and tearing down state since
    MultiProcessTestCase runs each `test*` method in a separate process and
    each process just runs the `test*` method without actually calling
    unittest's 'setUp' and 'tearDown' methods.
    """
    @wraps(test_method)
    def wrapper(self, *args, **kwargs):
        self.worker_id = self.rank
        # Initialize the default (gloo) process group before setting up RPC.
        dist.init_process_group(backend="gloo", init_method=self.init_method)
        # Use enough 'num_send_recv_threads' until we fix
        # https://github.com/pytorch/pytorch/issues/26359
        rpc.init_model_parallel(
            self_name="worker%d" % self.rank,
            backend=TEST_CONFIG.rpc_backend,
            self_rank=self.rank,
            init_method=self.init_method,
            num_send_recv_threads=16,
        )
        test_method(self, *args, **kwargs)
        # Tear down RPC state once the test body finishes.
        rpc.join_rpc()

    return wrapper
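

# Minimal usage sketch (illustrative, not part of the original file): the class
# name `ExampleRpcTest` and its test method are hypothetical, and the sketch
# assumes MultiProcessTestCase (from the shared distributed test utilities)
# provides `self.rank`, `self.world_size`, and `self.file_name`, with
# `self.init_method` built from INIT_METHOD_TEMPLATE:
#
#     class ExampleRpcTest(MultiProcessTestCase):
#         @property
#         def init_method(self):
#             return INIT_METHOD_TEMPLATE.format(
#                 file_name=self.file_name,
#                 rank=self.rank,
#                 world_size=self.world_size,
#             )
#
#         @dist_init
#         def test_worker_id(self):
#             # By this point the decorator has initialized the process group
#             # and the RPC layer and set `self.worker_id` to this rank.
#             self.assertEqual(self.worker_id, self.rank)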