from . import invoke_rpc_builtin, invoke_rpc_python_udf
from . import invoke_remote_builtin, invoke_remote_python_udf
from . import _start_rpc_agent
from . import _destroy_rref_context, _cleanup_python_rpc_handler
from . import WorkerInfo
from . import backend_registry
from .constants import DEFAULT_RPC_TIMEOUT, DEFAULT_NUM_SEND_RECV_THREADS
from .internal import _internal_rpc_pickler, PythonUDF
import functools
import sys
import torch

_agent = None


# Decorator that guards public RPC APIs: raise a RuntimeError if the RPC
# agent has not been initialized yet.
def _require_initialized(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if _agent is None:
raise RuntimeError(
"RPC has not been initialized. Call "
"torch.distributed.rpc.init_model_parallel first."
)
return func(*args, **kwargs)
return wrapper


def join_rpc():
    r"""
    Block until all local and remote RPC processes reach this method, process
    (send and receive) all pending messages, and then destroy the local RPC
    agent. Every RPC process must call this method before exiting.
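
    Example::

        A minimal sketch, assuming this process initialized RPC earlier
        (e.g. via ``init_model_parallel``) and has finished its RPC work:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.join_rpc()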
"""
global _agent
if _agent:
_agent.join()
_agent = None
_destroy_rref_context()
        # Clean up the Python RPC handler in join_rpc(); see the comments in
        # PythonRpcHandler::cleanup(). It is called from the Python API
        # because cleanup() has a Python dependency: it assumes a live
        # Python interpreter.
_cleanup_python_rpc_handler()


@_require_initialized
def sync_rpc():
r"""
Block until all local and remote RPC processes reach this method and finish
sending all pending RPCs. As this method synchronizes at the process
level, if multiple threads are spawned, only one of them should call this
method at a time.
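
    Example::

        A minimal sketch, assuming RPC has been initialized on this process
        and every peer calls ``sync_rpc()`` at the same point:
        >>> import torch.distributed.rpc as rpc
        >>> # Drain all outstanding RPC messages across the group.
        >>> rpc.sync_rpc()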
"""
_agent.sync()


# TODO: add a context manager to wrap _init_rpc and join_rpc
def _init_rpc(
backend=backend_registry.BackendType.PROCESS_GROUP,
store=None,
self_name=None,
self_rank=-1,
worker_name_to_id=None,
num_send_recv_threads=DEFAULT_NUM_SEND_RECV_THREADS,
rpc_timeout=DEFAULT_RPC_TIMEOUT,
):
    if sys.version_info < (3, 0):
        raise RuntimeError("RPC package does not support Python 2.")

    global _agent
if _agent:
raise RuntimeError("RPC is already initialized")
# Initialize RPC.
_agent = backend_registry.init_backend(
backend,
store=store,
self_name=self_name,
self_rank=self_rank,
worker_name_to_id=worker_name_to_id,
num_send_recv_threads=num_send_recv_threads,
rpc_timeout=rpc_timeout,
)
_start_rpc_agent(_agent)


@_require_initialized
def get_worker_info(worker_name=None):
r"""
    Get the ``WorkerInfo`` of a given worker name. Use this ``WorkerInfo`` to
    avoid passing an expensive string on every invocation. A ``WorkerInfo``
    contains the name and the id of the worker.

    Arguments:
        worker_name (str): the string name of a worker. If ``None``, return
            the ``WorkerInfo`` of the current worker. (default ``None``)

    Returns:
        ``WorkerInfo`` instance for the given ``worker_name``, or the
        ``WorkerInfo`` of the current worker if ``worker_name`` is ``None``.
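
    Example::

        A minimal sketch, assuming RPC has been initialized and a peer named
        ``"worker1"`` is part of the group:
        >>> worker1 = rpc.get_worker_info("worker1")
        >>> # Reuse the cached WorkerInfo in later calls instead of the name.
        >>> rref = rpc.remote(worker1, torch.add, args=(torch.ones(2), 3))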
"""
if worker_name:
return _agent.get_worker_info(worker_name)
else:
return _agent.get_worker_info()


@_require_initialized
def get_rpc_timeout():
"""
Retrieve the timeout for all RPCs that was set during RPC initialization.
Returns:
`datetime.timedelta` instance indicating the RPC timeout.
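
    Example::

        A minimal sketch, assuming RPC has been initialized on this process:
        >>> timeout = rpc.get_rpc_timeout()
        >>> # timeout is a datetime.timedelta object.
        >>> print(timeout.total_seconds())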
"""
return _agent._get_rpc_timeout()


def _to_worker_info(name_or_info):
if isinstance(name_or_info, WorkerInfo):
return name_or_info
elif isinstance(name_or_info, str):
return get_worker_info(name_or_info)
else:
        raise ValueError(
            "Cannot get WorkerInfo from name {}".format(name_or_info))


@_require_initialized
def remote(to, func, args=None, kwargs=None):
r"""
    Make a ``remote`` call to run ``func`` on worker ``to``, and return an
    ``RRef`` to the result value immediately. Worker ``to`` will be the owner
    of the returned ``RRef``, and this worker is a user. The owner manages the
    global reference count of its ``RRef``s, and the owner ``RRef`` is only
    destructed when globally there are no living references to it.

    Arguments:
        to (str or WorkerInfo): name or ``WorkerInfo`` of the destination
            worker.
        func (callable): any callable function. Builtin functions (like
            ``torch.add``) can be sent over RPC more efficiently.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): the dictionary of keyword arguments for the ``func``
            invocation.

    Returns:
        A user ``RRef`` instance to the result value. Use the blocking API
        ``RRef.to_here()`` to retrieve the result value locally.

    Example::

On worker 0:
>>> import torch.distributed as dist
>>> import torch.distributed.rpc as rpc
>>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
>>> rpc.init_model_parallel("worker0")
>>> worker1 = rpc.get_worker_info("worker1")
>>> rref1 = rpc.remote(worker1, torch.add, args=(torch.ones(2), 3))
>>> rref2 = rpc.remote(worker1, torch.add, args=(torch.ones(2), 1))
>>> x = rref1.to_here() + rref2.to_here()
>>> rpc.join_rpc()

        On worker 1:
        >>> import torch.distributed as dist
        >>> import torch.distributed.rpc as rpc
        >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
        >>> rpc.init_model_parallel("worker1")
        >>> rpc.join_rpc()
"""
qualified_name = torch.jit._find_builtin(func)
args = args if args else ()
kwargs = kwargs if kwargs else {}
info = _to_worker_info(to)
if qualified_name is not None:
return invoke_remote_builtin(
_agent, info, qualified_name, *args, **kwargs)
else:
(pickled_python_udf, tensors) = _internal_rpc_pickler.serialize(
PythonUDF(func, args, kwargs))
return invoke_remote_python_udf(
_agent, info, pickled_python_udf, tensors)


def _invoke_rpc(to, func, args=None, kwargs=None):
if not callable(func):
raise TypeError("function should be callable.")
qualified_name = torch.jit._find_builtin(func)
args = args if args else ()
kwargs = kwargs if kwargs else {}
info = _to_worker_info(to)
if qualified_name is not None:
fut = invoke_rpc_builtin(
_agent, info, qualified_name, *args, **kwargs
)
else:
(pickled_python_udf, tensors) = _internal_rpc_pickler.serialize(
PythonUDF(func, args, kwargs))
fut = invoke_rpc_python_udf(
_agent, info, pickled_python_udf, tensors)
return fut


@_require_initialized
def rpc_sync(to, func, args=None, kwargs=None):
r"""
Make a blocking RPC call to run function ``func`` on worker ``to``. RPC
messages are sent and received in parallel to execution of Python code. This
method is thread-safe.
    Arguments:
        to (str or WorkerInfo): name or ``WorkerInfo`` of the destination
            worker.
        func (callable): any callable function. Builtin functions (like
            ``torch.add``) can be sent over RPC more efficiently.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): the dictionary of keyword arguments for the ``func``
            invocation.

    Returns:
        Returns the result of running ``func`` on ``args`` and ``kwargs``.

    Example::

On worker 0:
>>> import torch.distributed as dist
>>> import torch.distributed.rpc as rpc
>>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
>>> rpc.init_model_parallel("worker0")
>>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
>>> rpc.join_rpc()

        On worker 1:
>>> import torch.distributed as dist
>>> import torch.distributed.rpc as rpc
>>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
>>> rpc.init_model_parallel("worker1")
>>> rpc.join_rpc()
"""
fut = _invoke_rpc(to, func, args, kwargs)
return fut.wait()


@_require_initialized
def rpc_async(to, func, args=None, kwargs=None):
r"""
    Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC
    messages are sent and received in parallel to execution of Python code.
    This method is thread-safe. It immediately returns a
    ``torch.distributed.FutureMessage`` that can be awaited on.

    Arguments:
        to (str or WorkerInfo): name or ``WorkerInfo`` of the destination
            worker.
        func (callable): any callable function. Builtin functions (like
            ``torch.add``) can be sent over RPC more efficiently.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): the dictionary of keyword arguments for the ``func``
            invocation.

    Returns:
        Returns a ``torch.distributed.FutureMessage`` object that can be waited
        on. When completed, the return value of ``func`` on ``args`` and
        ``kwargs`` can be retrieved from the ``FutureMessage`` object.

    Example::

On worker 0:
>>> import torch.distributed as dist
>>> import torch.distributed.rpc as rpc
>>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
>>> rpc.init_model_parallel("worker0")
        >>> worker1 = rpc.get_worker_info("worker1")
>>> fut1 = rpc.rpc_async(worker1, torch.add, args=(torch.ones(2), 3))
>>> fut2 = rpc.rpc_async(worker1, min, args=(1, 2))
>>> result = fut1.wait() + fut2.wait()
>>> rpc.join_rpc()

        On worker 1:
>>> import torch.distributed as dist
>>> import torch.distributed.rpc as rpc
>>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
>>> rpc.init_model_parallel("worker1")
>>> rpc.join_rpc()
"""
fut = _invoke_rpc(to, func, args, kwargs)
return fut