blob: 679a860befd09ad3df4385eb8f082e5c6560686d [file] [log] [blame]
from __future__ import absolute_import, division, print_function, unicode_literals
import collections
import datetime
import enum
import torch.distributed as dist
import torch.distributed.distributed_c10d as dc10d
from . import constants as rpc_constants
BackendValue = collections.namedtuple(
"BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"]
)
def _backend_type_repr(self):
return "BackendType." + self.name
# Create an enum type, `BackendType`, with empty members.
BackendType = enum.Enum(value="BackendType", names={})
BackendType.__repr__ = _backend_type_repr
def register_backend(
backend_name, construct_rpc_backend_options_handler, init_backend_handler
):
"""Registers a new RPC backend.
Arguments:
backend_name (str): backend string to identify the handler.
construct_rpc_backend_options_handler (function):
Handler that is invoked when
rpc_backend.construct_rpc_backend_options(**dict) is called.
init_backend_handler (function): Handler that is invoked when the
`_init_rpc_backend()` function is called with a backend.
This returns the agent.
"""
global BackendType
if backend_name in BackendType.__members__.keys():
raise RuntimeError("RPC backend {}: already registered".format(backend_name))
# Create a new enum type, `BackendType`, with extended members.
existing_enum_dict = {member.name: member.value for member in BackendType}
extended_enum_dict = dict(
{
backend_name: BackendValue(
construct_rpc_backend_options_handler=construct_rpc_backend_options_handler,
init_backend_handler=init_backend_handler,
)
},
**existing_enum_dict
)
BackendType = enum.Enum(value="BackendType", names=extended_enum_dict)
BackendType.__repr__ = _backend_type_repr
return BackendType[backend_name]
def construct_rpc_backend_options(
backend,
rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT,
init_method=rpc_constants.DEFAULT_INIT_METHOD,
**kwargs
):
if not isinstance(rpc_timeout, datetime.timedelta):
raise RuntimeError("`rpc_timeout` must be a `datetime.timedelta`.")
return backend.value.construct_rpc_backend_options_handler(
rpc_timeout, init_method, **kwargs
)
def init_backend(backend, *args, **kwargs):
return backend.value.init_backend_handler(*args, **kwargs)
def _process_group_construct_rpc_backend_options_handler(
rpc_timeout,
init_method,
num_send_recv_threads=rpc_constants.DEFAULT_NUM_SEND_RECV_THREADS,
**kwargs
):
from . import ProcessGroupRpcBackendOptions
return ProcessGroupRpcBackendOptions(
rpc_timeout=rpc_timeout,
init_method=init_method,
num_send_recv_threads=num_send_recv_threads
)
def _process_group_init_backend_handler(
store, name, rank, world_size, rpc_backend_options
):
from . import ProcessGroupAgent
# Initialize ProcessGroup.
if dist.is_initialized():
raise RuntimeError(
"Default process group must not be initialized before init_rpc."
)
process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT
dist.init_process_group(
backend=dist.Backend.GLOO,
store=store,
rank=rank,
world_size=world_size,
timeout=process_group_timeout,
)
try:
group = dc10d._get_default_group()
assert group is not None, "Failed to initialize default ProcessGroup."
if (rank != -1) and (rank != group.rank()):
raise RuntimeError(
"rank argument {} doesn't match pg rank {}".format(rank, group.rank())
)
if (world_size != -1) and (world_size != group.size()):
raise RuntimeError(
"world_size argument {} doesn't match pg size {}".format(
world_size, group.size()
)
)
# TODO: add try-except and destroy _agent in all processes if any fails.
return ProcessGroupAgent(
name,
group,
rpc_backend_options.num_send_recv_threads,
rpc_backend_options.rpc_timeout,
)
except Exception as ex:
dist.destroy_process_group()
raise ex
register_backend(
"PROCESS_GROUP",
_process_group_construct_rpc_backend_options_handler,
_process_group_init_backend_handler,
)