| #include <torch/csrc/python_headers.h> |
| |
| #include <torch/csrc/distributed/rpc/process_group_agent.h> |
| #include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h> |
| #include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h> |
| #include <torch/csrc/distributed/rpc/py_rref.h> |
| #include <torch/csrc/distributed/rpc/python_functions.h> |
| #include <torch/csrc/distributed/rpc/python_rpc_handler.h> |
| #include <torch/csrc/distributed/rpc/request_callback_impl.h> |
| #include <torch/csrc/distributed/rpc/rpc_agent.h> |
| #include <torch/csrc/distributed/rpc/rref_context.h> |
| #include <torch/csrc/distributed/rpc/tensorpipe_agent.h> |
| #include <torch/csrc/distributed/rpc/torchscript_functions.h> |
| #include <torch/csrc/distributed/rpc/types.h> |
| #include <torch/csrc/jit/python/pybind_utils.h> |
| #include <torch/csrc/utils/object_ptr.h> |
| #include <torch/csrc/utils/pybind.h> |
| #include <torch/types.h> |
| |
| #include <pybind11/chrono.h> |
| #include <pybind11/operators.h> |
| |
| namespace torch { |
| namespace distributed { |
| namespace rpc { |
| |
| namespace { |
| |
| constexpr std::chrono::milliseconds kDeleteAllUsersTimeout(100000); |
| |
| template <typename T> |
| using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>; |
| |
| PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { |
| auto rpc_module = |
| THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc")); |
| if (!rpc_module) { |
| throw python_error(); |
| } |
| |
| auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C")); |
| if (!torch_C_module) { |
| throw python_error(); |
| } |
| |
| auto torch_C_m = py::handle(torch_C_module).cast<py::module>(); |
| auto m = |
| torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings"); |
| |
| auto module = py::handle(m).cast<py::module>(); |
| |
| auto rpcBackendOptions = |
| shared_ptr_class_<RpcBackendOptions>( |
| module, |
| "RpcBackendOptions", |
| R"(An abstract structure encapsulating the options passed into the RPC |
| backend. An instance of this class can be passed in to |
| :meth:`~torch.distributed.rpc.init_rpc` in order to initialize RPC |
| with specific configurations, such as the RPC timeout and |
| ``init_method`` to be used.)") |
| .def(py::init<>()) |
| .def( |
| py::init<float, std::string>(), |
| py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds, |
| py::arg("init_method") = kDefaultInitMethod) |
| .def_readwrite( |
| "rpc_timeout", |
| &RpcBackendOptions::rpcTimeoutSeconds, |
| R"(A float indicating the timeout to use for all |
| RPCs. If an RPC does not complete in this timeframe, it will |
| complete with an exception indicating that it has timed out.)") |
| .def_readwrite( |
| "init_method", |
| &RpcBackendOptions::initMethod, |
| R"(URL specifying how to initialize the process group. |
| Default is ``env://``.)"); |
| |
| // The following C++ constants need to be cast so they can be used from |
| // Python. |
| module.attr("_DEFAULT_RPC_TIMEOUT_SEC") = py::cast(kDefaultRpcTimeoutSeconds); |
| module.attr("_UNSET_RPC_TIMEOUT") = py::cast(kUnsetRpcTimeout); |
| module.attr("_DEFAULT_INIT_METHOD") = py::cast(kDefaultInitMethod); |
| |
| auto workerInfo = |
| shared_ptr_class_<WorkerInfo>( |
| module, |
| "WorkerInfo", |
| R"(A structure that encapsulates information of a worker in the system. |
| Contains the name and ID of the worker. This class is not meant to |
| be constructed directly, rather, an instance can be retrieved |
| through :meth:`~torch.distributed.rpc.get_worker_info` and the |
| result can be passed in to functions such as |
| :meth:`~torch.distributed.rpc.rpc_sync`, :meth:`~torch.distributed.rpc.rpc_async`, |
| :meth:`~torch.distributed.rpc.remote` to avoid copying a string on |
| every invocation.)") |
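| |
| Example:: |
| A minimal sketch, assuming RPC has been initialized and a peer |
| named ``worker1`` exists: |
| |
| >>> import torch |
| >>> import torch.distributed.rpc as rpc |
| >>> info = rpc.get_worker_info("worker1") |
| >>> # reuse the WorkerInfo handle instead of the string name |
| >>> ret = rpc.rpc_sync(info, torch.add, args=(torch.ones(2), 1)) |
| )") |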
| .def( |
| py::init<std::string, worker_id_t>(), |
| py::arg("name"), |
| py::arg("id")) |
| .def_readonly( |
| "name", &WorkerInfo::name_, R"(The name of the worker.)") |
| .def_readonly( |
| "id", |
| &WorkerInfo::id_, |
| R"(Globally unique id to identify the worker.)") |
| .def("__eq__", &WorkerInfo::operator==, py::is_operator()) |
| // pybind11 suggests the syntax .def(hash(py::self)), with the |
| // unqualified "hash" function call. However the |
| // argument-dependent lookup for the function "hash" doesn't get |
| // triggered in this context because it conflicts with the struct |
| // c10::hash, so we need to use the qualified name |
| // py::detail::hash, which unfortunately is in a detail namespace. |
| .def(py::detail::hash(py::self)) // NOLINT |
| .def("__repr__", [](const WorkerInfo& workerInfo) { |
| std::ostringstream os; |
| os << workerInfo; |
| return os.str(); |
| }); |
| |
| auto rpcAgent = |
| shared_ptr_class_<RpcAgent>(module, "RpcAgent") |
| .def( |
| "join", &RpcAgent::join, py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "shutdown", |
| &RpcAgent::shutdown, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "get_worker_info", |
| (const WorkerInfo& (RpcAgent::*)(void) const) & |
| RpcAgent::getWorkerInfo, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "get_worker_info", |
| (const WorkerInfo& (RpcAgent::*)(const std::string&) const) & |
| RpcAgent::getWorkerInfo, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "get_worker_infos", |
| &RpcAgent::getWorkerInfos, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "get_debug_info", |
| &RpcAgent::getDebugInfo, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "get_metrics", |
| &RpcAgent::getMetrics, |
| py::call_guard<py::gil_scoped_release>()); |
| |
| auto pyRRef = |
| shared_ptr_class_<PyRRef>(module, "PyRRef", R"( |
| A class encapsulating a reference to a value of some type on a remote |
| worker. This handle will keep the referenced remote value alive on the |
| worker. A ``UserRRef`` will be deleted when 1) there are no references |
| to it in either the application code or the local RRef context, or 2) |
| the application has called a graceful shutdown. Invoking methods on a |
| deleted RRef leads to undefined behavior. The RRef implementation only |
| offers best-effort error detection, and applications should not use |
| ``UserRRefs`` after ``rpc.shutdown()``. |
| |
| .. warning:: |
| RRefs can only be serialized and deserialized by the RPC module. |
| Serializing and deserializing RRefs without RPC (e.g., Python |
| pickle, torch :meth:`~torch.save` / :meth:`~torch.load`, |
| JIT :meth:`~torch.jit.save` / :meth:`~torch.jit.load`, etc.) will |
| lead to errors. |
| |
| Example:: |
| The following examples skip the RPC initialization and shutdown |
| code for simplicity. Refer to the RPC docs for those details. |
| |
| 1. Create an RRef using rpc.remote |
| |
| >>> import torch |
| >>> import torch.distributed.rpc as rpc |
| >>> rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) |
| >>> # get a copy of value from the RRef |
| >>> x = rref.to_here() |
| |
| 2. Create an RRef from a local object |
| |
| >>> import torch |
| >>> from torch.distributed.rpc import RRef |
| >>> x = torch.zeros(2, 2) |
| >>> rref = RRef(x) |
| |
| 3. Share an RRef with other workers |
| |
| >>> # On both worker0 and worker1: |
| >>> def f(rref): |
| >>> return rref.to_here() + 1 |
| |
| >>> # On worker0: |
| >>> import torch |
| >>> import torch.distributed.rpc as rpc |
| >>> from torch.distributed.rpc import RRef |
| >>> rref = RRef(torch.zeros(2, 2)) |
| >>> # the following RPC shares the rref with worker1; the reference |
| >>> # count is automatically updated. |
| >>> rpc.rpc_sync("worker1", f, args=(rref,)) |
| )") |
| .def( |
| py::init<const py::object&, const py::object&>(), |
| py::arg("value"), |
| py::arg("type_hint") = py::none()) |
| .def( |
| // not releasing GIL here to avoid context switch on getters |
| "is_owner", |
| &PyRRef::isOwner, |
| R"( |
| Returns whether or not the current node is the owner of this |
| ``RRef``. |
| )") |
| .def( |
| "confirmed_by_owner", |
| &PyRRef::confirmedByOwner, |
| R"( |
| Returns whether this ``RRef`` has been confirmed by the owner. |
| ``OwnerRRef`` always returns true, while ``UserRRef`` only |
| returns true when the owner knows about this ``UserRRef``. |
| )") |
| .def( |
| // not releasing GIL here to avoid context switch on getters |
| "owner", |
| &PyRRef::owner, |
| R"( |
| Returns the worker information of the node that owns this ``RRef``. |
| )") |
| .def( |
| // not releasing GIL here to avoid context switch on getters |
| "owner_name", |
| &PyRRef::ownerName, |
| R"( |
| Returns the worker name of the node that owns this ``RRef``. |
| )") |
| .def( |
| "to_here", |
| &PyRRef::toHere, |
| py::arg("timeout") = py::cast(kUnsetRpcTimeout), |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Blocking call that copies the value of the RRef from the owner |
| to the local node and returns it. If the current node is the |
| owner, returns a reference to the local value. |
| |
| Args: |
| timeout (float, optional): Timeout for ``to_here``. If |
| the call does not complete within this timeframe, an |
| exception indicating so will be raised. If this |
| argument is not provided, the default RPC timeout |
| (60s) will be used. |
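| |
| Example:: |
| A minimal sketch, assuming RPC is initialized and a peer named |
| ``worker1`` exists: |
| |
| >>> import torch |
| >>> import torch.distributed.rpc as rpc |
| >>> rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1)) |
| >>> rref.to_here()  # blocks until the value has been copied over |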
| )") |
| .def( |
| "local_value", |
| &PyRRef::localValue, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| If the current node is the owner, returns a reference to the |
| local value. Otherwise, throws an exception. |
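| |
| Example:: |
| A minimal sketch, run on the owner of the ``RRef``: |
| |
| >>> import torch |
| >>> from torch.distributed.rpc import RRef |
| >>> rref = RRef(torch.ones(2)) |
| >>> rref.local_value()  # returns the wrapped tensor itself |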
| )") |
| .def( |
| "rpc_sync", |
| [](const PyRRef& self, float timeoutSeconds) { |
| return self.createRRefProxy( |
| RRefProxyType::RPC_SYNC, timeoutSeconds); |
| }, |
| py::arg("timeout") = kUnsetRpcTimeout, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Create a helper proxy to easily launch an ``rpc_sync`` using |
| the owner of the RRef as the destination to run functions on |
| the object referenced by this RRef. More specifically, |
| ``rref.rpc_sync().func_name(*args, **kwargs)`` is the same as |
| the following: |
| |
| >>> def run(rref, func_name, args, kwargs): |
| >>> return getattr(rref.local_value(), func_name)(*args, **kwargs) |
| >>> |
| >>> rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs)) |
| |
| Args: |
| timeout (float, optional): Timeout for ``rref.rpc_sync()``. |
| If the call does not complete within this timeframe, an |
| exception indicating so will be raised. If this argument |
| is not provided, the default RPC timeout will be used. |
| |
| Example:: |
| >>> from torch.distributed import rpc |
| >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1)) |
| >>> rref.rpc_sync().size() # returns torch.Size([2, 2]) |
| >>> rref.rpc_sync().view(1, 4) # returns tensor([[1., 1., 1., 1.]]) |
| )") |
| .def( |
| "rpc_async", |
| [](const PyRRef& self, float timeoutSeconds) { |
| return self.createRRefProxy( |
| RRefProxyType::RPC_ASYNC, timeoutSeconds); |
| }, |
| py::arg("timeout") = kUnsetRpcTimeout, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Create a helper proxy to easily launch an ``rpc_async`` using |
| the owner of the RRef as the destination to run functions on |
| the object referenced by this RRef. More specifically, |
| ``rref.rpc_async().func_name(*args, **kwargs)`` is the same as |
| the following: |
| |
| >>> def run(rref, func_name, args, kwargs): |
| >>> return getattr(rref.local_value(), func_name)(*args, **kwargs) |
| >>> |
| >>> rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs)) |
| |
| Args: |
| timeout (float, optional): Timeout for ``rref.rpc_async()``. |
| If the call does not complete within this timeframe, an |
| exception indicating so will be raised. If this argument |
| is not provided, the default RPC timeout will be used. |
| |
| Example:: |
| >>> from torch.distributed import rpc |
| >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1)) |
| >>> rref.rpc_async().size().wait() # returns torch.Size([2, 2]) |
| >>> rref.rpc_async().view(1, 4).wait() # returns tensor([[1., 1., 1., 1.]]) |
| )") |
| .def( |
| "remote", |
| [](const PyRRef& self, float timeoutSeconds) { |
| return self.createRRefProxy( |
| RRefProxyType::REMOTE, timeoutSeconds); |
| }, |
| py::arg("timeout") = kUnsetRpcTimeout, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Create a helper proxy to easily launch a ``remote`` using |
| the owner of the RRef as the destination to run functions on |
| the object referenced by this RRef. More specifically, |
| ``rref.remote().func_name(*args, **kwargs)`` is the same as |
| the following: |
| |
| >>> def run(rref, func_name, args, kwargs): |
| >>> return getattr(rref.local_value(), func_name)(*args, **kwargs) |
| >>> |
| >>> rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs)) |
| |
| Args: |
| timeout (float, optional): Timeout for ``rref.remote()``. If |
| the creation of this :class:`~torch.distributed.rpc.RRef` |
| is not successfully completed within the timeout, then the |
| next time there is an attempt to use the RRef |
| (such as ``to_here``), a timeout will be raised. If not |
| provided, the default RPC timeout will be used. Please see |
| ``rpc.remote()`` for specific timeout semantics for |
| :class:`~torch.distributed.rpc.RRef`. |
| |
| Example:: |
| >>> from torch.distributed import rpc |
| >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1)) |
| >>> rref.remote().size().to_here() # returns torch.Size([2, 2]) |
| >>> rref.remote().view(1, 4).to_here() # returns tensor([[1., 1., 1., 1.]]) |
| )") |
| .def( |
| py::pickle( |
| /* __getstate__ */ |
| [](const PyRRef& /* unused */) { |
| TORCH_CHECK( |
| false, |
| "Can not pickle rref in python pickler, rref can only be " |
| "pickled when using RPC"); |
| // Note that this return has no meaning since we always |
| // throw, it's only here to satisfy Pybind API's |
| // requirement. |
| return py::make_tuple(); |
| }, |
| /* __setstate__ */ |
| [](py::tuple /* unused */) { // NOLINT |
| TORCH_CHECK( |
| false, |
| "Can not unpickle rref in python pickler, rref can only be " |
| "unpickled when using RPC"); |
| // Note that this return has no meaning since we always |
| // throw, it's only here to satisfy PyBind's API |
| // requirement. |
| return PyRRef( |
| py::cast<py::none>(Py_None), |
| py::cast<py::none>(Py_None)); |
| }), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_serialize", |
| &PyRRef::pickle, |
| py::call_guard<py::gil_scoped_release>()) |
| .def_static( |
| "_deserialize", |
| &PyRRef::unpickle, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_get_type", |
| // Intentionally not releasing GIL, as most accesses just |
| // retrieve cached type py::object |
| &PyRRef::getRRefType, |
| py::arg("timeout") = kUnsetRpcTimeout, |
| R"( |
| Returns the type of the data object referenced by this |
| ``RRef``. On the owner, this is the same as |
| ``type(rref.local_value())``. On a user, this will trigger an |
| RPC to fetch the ``type`` object from the owner. After this |
| function is run once, the ``type`` object is cached by the |
| ``RRef``, and subsequent invocations no longer trigger RPC. |
| |
| Args: |
| timeout (float, optional): Timeout, in seconds, for |
| ``_get_type``. If the call does not complete within |
| this timeframe, an exception indicating so will be |
| raised. If this argument is not provided, the default |
| RPC timeout will be used. |
| )") |
| .def( |
| "_get_future", |
| [](const PyRRef& self) { |
| return std::make_shared<jit::PythonFutureWrapper>( |
| self.getFuture()); |
| }, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Returns the future that corresponds to the creation of this RRef |
| on the remote node. This is for internal use cases such as profiling |
| only. |
| )") |
| .def( |
| "_get_profiling_future", |
| [](const PyRRef& self) { |
| return std::make_shared<jit::PythonFutureWrapper>( |
| self.getProfilingFuture()); |
| }, |
| py::call_guard<py::gil_scoped_acquire>(), |
| R"( |
| Returns a future that completes when the profiling event corresponding |
| to the creation of this RRef on the remote node has been recorded. |
| )") |
| .def( |
| "_set_profiling_future", |
| [](PyRRef& self, |
| const std::shared_ptr<jit::PythonFutureWrapper>& |
| wrappedFuture) { |
| self.setProfilingFuture(wrappedFuture->fut); |
| }, |
| py::call_guard<py::gil_scoped_acquire>(), |
| R"( |
| Sets the future that is completed when the profiling event corresponding |
| to the creation of this RRef on the remote node has been recorded. |
| )") |
| .def( |
| "backward", |
| [](PyRRef& self, |
| int64_t dist_autograd_ctx_id, |
| bool retain_graph) { |
| self.backward(dist_autograd_ctx_id, retain_graph); |
| }, |
| py::arg("dist_autograd_ctx_id") = -1, |
| py::arg("retain_graph") = false, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Runs the backward pass using the RRef as the root of the |
| backward pass. If ``dist_autograd_ctx_id`` is provided, |
| we perform a distributed backward pass using the provided |
| ctx_id starting from the owner of the RRef. In this case, |
| :meth:`~torch.distributed.autograd.get_gradients` should be |
| used to retrieve the gradients. If ``dist_autograd_ctx_id`` |
| is left at its default (``-1``), it is assumed that this is a local autograd graph |
| and we only perform a local backward pass. In the local case, |
| the node calling this API has to be the owner of the RRef. |
| The value of the RRef is expected to be a scalar Tensor. |
| |
| Args: |
| dist_autograd_ctx_id (int, optional): The distributed |
| autograd context id for which we should retrieve the |
| gradients (default: -1). |
| retain_graph (bool, optional): If ``False``, the graph used to |
| compute the grad will be freed. Note that in nearly all |
| cases setting this option to ``True`` is not needed and |
| often can be worked around in a much more efficient way. |
| Usually, you need to set this to ``True`` to run backward |
| multiple times (default: False). |
| |
| Example:: |
| >>> import torch.distributed.autograd as dist_autograd |
| >>> with dist_autograd.context() as context_id: |
| >>> rref.backward(context_id) |
| )") |
| // not releasing GIL to avoid context switch |
| .def("__repr__", &PyRRef::str); |
| |
| shared_ptr_class_<ProcessGroupRpcBackendOptions>( |
| module, |
| "ProcessGroupRpcBackendOptions", |
| rpcBackendOptions, |
| R"( |
| The backend options class for ``ProcessGroupAgent``, which is derived |
| from ``RpcBackendOptions``. |
| |
| Args: |
| num_send_recv_threads (int, optional): The number of threads in |
| the thread-pool used by ``ProcessGroupAgent`` (default: 4). |
| rpc_timeout (float, optional): The default timeout, in seconds, |
| for RPC requests (default: 60 seconds). If the |
| RPC has not completed in this timeframe, an exception |
| indicating so will be raised. Callers can override this |
| timeout for individual RPCs in |
| :meth:`~torch.distributed.rpc.rpc_sync` and |
| :meth:`~torch.distributed.rpc.rpc_async` if necessary. |
| init_method (str, optional): The URL to initialize |
| ``ProcessGroupGloo`` (default: ``env://``). |
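| |
| Example:: |
| A minimal sketch, assuming a single-worker setup with the default |
| ``env://`` init method: |
| |
| >>> import os |
| >>> from torch.distributed import rpc |
| >>> os.environ['MASTER_ADDR'] = 'localhost' |
| >>> os.environ['MASTER_PORT'] = '29500' |
| >>> rpc.init_rpc( |
| >>>     "worker0", |
| >>>     rank=0, |
| >>>     world_size=1, |
| >>>     rpc_backend_options=rpc.ProcessGroupRpcBackendOptions( |
| >>>         num_send_recv_threads=16, |
| >>>         rpc_timeout=20, # seconds |
| >>>     ), |
| >>> ) |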
| )") |
| .def( |
| py::init<int, float, std::string>(), |
| py::arg("num_send_recv_threads") = kDefaultNumSendRecvThreads, |
| py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds, |
| py::arg("init_method") = kDefaultInitMethod) |
| .def_readwrite( |
| "num_send_recv_threads", |
| &ProcessGroupRpcBackendOptions::numSendRecvThreads, |
| R"( |
| The number of threads in the thread-pool used by ``ProcessGroupAgent``. |
| )"); |
| |
| module.attr("_DEFAULT_NUM_SEND_RECV_THREADS") = |
| py::cast(kDefaultNumSendRecvThreads); |
| |
| shared_ptr_class_<ProcessGroupAgent>(module, "ProcessGroupAgent", rpcAgent) |
| .def(py::init([](std::string workerName, |
| const c10::intrusive_ptr<::c10d::ProcessGroup>& pg, |
| int numSendRecvThreads, |
| std::chrono::milliseconds rpcTimeout) { |
| return std::make_unique<ProcessGroupAgent>( |
| std::move(workerName), |
| pg, |
| numSendRecvThreads, |
| rpcTimeout, |
| std::make_unique<RequestCallbackImpl>()); |
| })) |
| .def( |
| "get_worker_info", |
| (const WorkerInfo& (ProcessGroupAgent::*)(void) const) & |
| RpcAgent::getWorkerInfo, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "get_worker_info", |
| (const WorkerInfo& (ProcessGroupAgent::*)(const std::string&) const) & |
| ProcessGroupAgent::getWorkerInfo, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "get_worker_info", |
| (const WorkerInfo& (ProcessGroupAgent::*)(worker_id_t id) const) & |
| ProcessGroupAgent::getWorkerInfo, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "get_worker_infos", |
| (std::vector<WorkerInfo>(ProcessGroupAgent::*)() const) & |
| ProcessGroupAgent::getWorkerInfos, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "join", |
| &ProcessGroupAgent::join, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "shutdown", |
| &ProcessGroupAgent::shutdown, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "sync", |
| &ProcessGroupAgent::sync, |
| py::call_guard<py::gil_scoped_release>()); |
| |
| #ifdef USE_TENSORPIPE |
| |
| // Base class: torch.distributed.rpc.RpcBackendOptions. |
| py::class_<TensorPipeRpcBackendOptions>( |
| module, "_TensorPipeRpcBackendOptionsBase", rpcBackendOptions) |
| .def( |
| py::init< |
| int, |
| optional<std::vector<std::string>>, |
| optional<std::vector<std::string>>, |
| float, |
| std::string, |
| std::unordered_map<std::string, tensorpipe::DeviceMap>>(), |
| py::arg("num_worker_threads") = kDefaultNumWorkerThreads, |
| py::arg("_transports") = optional<std::vector<std::string>>(), |
| py::arg("_channels") = optional<std::vector<std::string>>(), |
| py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds, |
| py::arg("init_method") = kDefaultInitMethod, |
| py::arg("device_maps") = |
| std::unordered_map<std::string, tensorpipe::DeviceMap>()) |
| .def_readwrite( |
| "num_worker_threads", |
| &TensorPipeRpcBackendOptions::numWorkerThreads, |
| R"( |
| The number of threads in the thread-pool used by |
| :class:`~torch.distributed.rpc.TensorPipeAgent` to execute |
| requests. |
| )") |
| .def_readwrite( |
| "device_maps", |
| &TensorPipeRpcBackendOptions::deviceMaps, |
| R"(The device map locations.)") |
| .def("set_device_map", &TensorPipeRpcBackendOptions::setDeviceMap); |
| |
| module.attr("_DEFAULT_NUM_WORKER_THREADS") = |
| py::cast(kDefaultNumWorkerThreads); |
| |
| shared_ptr_class_<TensorPipeAgent>(module, "TensorPipeAgent", rpcAgent) |
| .def( |
| py::init([](const c10::intrusive_ptr<::c10d::Store>& store, |
| std::string selfName, |
| worker_id_t selfId, |
| int worldSize, |
| c10::intrusive_ptr<::c10d::ProcessGroup> processGroup, |
| TensorPipeRpcBackendOptions opts) { |
| return std::make_shared<TensorPipeAgent>( |
| store, |
| std::move(selfName), |
| selfId, |
| worldSize, |
| std::move(processGroup), |
| std::move(opts), |
| std::make_unique<RequestCallbackImpl>()); |
| }), |
| py::arg("store"), |
| py::arg("name"), |
| py::arg("rank"), |
| py::arg("world_size"), |
| py::arg("process_group"), |
| py::arg("rpc_backend_options")) |
| .def( |
| "join", |
| &TensorPipeAgent::join, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "shutdown", |
| &TensorPipeAgent::shutdown, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "get_worker_info", |
| (const WorkerInfo& (TensorPipeAgent::*)(void) const) & |
| RpcAgent::getWorkerInfo, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "get_worker_info", |
| (const WorkerInfo& (TensorPipeAgent::*)(const std::string&) const) & |
| TensorPipeAgent::getWorkerInfo, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "get_worker_info", |
| (const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id) const) & |
| TensorPipeAgent::getWorkerInfo, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "get_worker_infos", |
| (std::vector<WorkerInfo>(TensorPipeAgent::*)() const) & |
| TensorPipeAgent::getWorkerInfos, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_set_reverse_device_maps", |
| // intentionally not releasing GIL to avoid unnecessary context switch |
| &TensorPipeAgent::setReverseDeviceMaps); |
| |
| #endif // USE_TENSORPIPE |
| |
| module.def("_is_current_rpc_agent_set", &RpcAgent::isCurrentRpcAgentSet); |
| |
| module.def("_get_current_rpc_agent", &RpcAgent::getCurrentRpcAgent); |
| |
| module.def( |
| "_set_and_start_rpc_agent", |
| [](const std::shared_ptr<RpcAgent>& rpcAgent) { |
| RpcAgent::setCurrentRpcAgent(rpcAgent); |
| // Initializing the typeResolver inside the RpcAgent constructor would |
| // give RpcAgent a Python dependency. To avoid that, the type resolver |
| // is set here via setTypeResolver() instead. |
| std::shared_ptr<TypeResolver> typeResolver = |
| std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) { |
| auto typePtr = PythonRpcHandler::getInstance().parseTypeFromStr( |
| qn.qualifiedName()); |
| return c10::StrongTypePtr( |
| PythonRpcHandler::getInstance().jitCompilationUnit(), |
| std::move(typePtr)); |
| }); |
| rpcAgent->setTypeResolver(typeResolver); |
| rpcAgent->start(); |
| }, |
| py::call_guard<py::gil_scoped_release>()); |
| |
| module.def("_reset_current_rpc_agent", []() { |
| RpcAgent::setCurrentRpcAgent(nullptr); |
| }); |
| |
| module.def( |
| "_delete_all_user_and_unforked_owner_rrefs", |
| [](std::chrono::milliseconds timeoutMillis) { |
| RRefContext::getInstance().delAllUsersAndUnforkedOwners(timeoutMillis); |
| }, |
| py::arg("timeout") = kDeleteAllUsersTimeout, |
| py::call_guard<py::gil_scoped_release>()); |
| |
| module.def("_destroy_rref_context", [](bool ignoreRRefLeak) { |
| // NB: do not release GIL in the function. The destroyInstance() method |
| // returns a list of deleted OwnerRRefs that hold py::object instances. |
| // Clearing those OwnerRRefs is likely to trigger Python deref, which |
| // requires GIL. |
| RRefContext::getInstance().destroyInstance(ignoreRRefLeak).clear(); |
| }); |
| |
| module.def("_rref_context_get_debug_info", []() { |
| return RRefContext::getInstance().getDebugInfo(); |
| }); |
| |
| module.def( |
| "_cleanup_python_rpc_handler", |
| []() { PythonRpcHandler::getInstance().cleanup(); }, |
| py::call_guard<py::gil_scoped_release>()); |
| |
| module.def( |
| "_invoke_rpc_builtin", |
| [](const WorkerInfo& dst, |
| const std::string& opName, |
| const float rpcTimeoutSeconds, |
| const py::args& args, |
| const py::kwargs& kwargs) { |
| return std::make_shared<jit::PythonFutureWrapper>( |
| pyRpcBuiltin(dst, opName, args, kwargs, rpcTimeoutSeconds)); |
| }, |
| py::call_guard<py::gil_scoped_acquire>()); |
| |
| module.def( |
| "_invoke_rpc_python_udf", |
| [](const WorkerInfo& dst, |
| std::string& pickledPythonUDF, |
| std::vector<torch::Tensor>& tensors, |
| const float rpcTimeoutSeconds, |
| const bool isAsyncExecution) { |
| return std::make_shared<jit::PythonFutureWrapper>( |
| pyRpcPythonUdf( |
| dst, |
| pickledPythonUDF, |
| tensors, |
| rpcTimeoutSeconds, |
| isAsyncExecution), |
| /* unwrap_func */ [](const py::object& value) { |
| py::gil_scoped_release release; |
| auto& pythonRpcHandler = PythonRpcHandler::getInstance(); |
| // This will unwrap RemoteException and raise the contained |
| // server-side Python exception on the client side. A caveat here |
| // is that the exception must be raised in the client thread that |
| // calls the pybind "wait" API, so that it can be correctly shown |
| // to the user. Raising it in the RPC server thread would be wrong: |
| // the exception would be swallowed by the ThreadPool task, and no |
| // pybind handling code could surface the Python exception. |
| pythonRpcHandler.handleException(value); |
| }); |
| }, |
| py::call_guard<py::gil_scoped_release>()); |
| |
| module.def( |
| "_invoke_rpc_torchscript", |
| [](const std::string& dstWorkerName, |
| const std::string& qualifiedNameStr, |
| const py::tuple& argsTuple, |
| const py::dict& kwargsDict, |
| const float rpcTimeoutSeconds, |
| const bool isAsyncExecution) { |
| return std::make_shared<jit::PythonFutureWrapper>(pyRpcTorchscript( |
| dstWorkerName, |
| qualifiedNameStr, |
| argsTuple, |
| kwargsDict, |
| rpcTimeoutSeconds, |
| isAsyncExecution)); |
| }, |
| py::call_guard<py::gil_scoped_release>()); |
| |
| module.def( |
| "_invoke_remote_builtin", |
| &pyRemoteBuiltin, |
| py::call_guard<py::gil_scoped_acquire>()); |
| |
| module.def( |
| "_invoke_remote_python_udf", |
| &pyRemotePythonUdf, |
| py::call_guard<py::gil_scoped_release>()); |
| |
| module.def( |
| "_invoke_remote_torchscript", |
| &pyRemoteTorchscript, |
| py::call_guard<py::gil_scoped_release>()); |
| |
| module.def( |
| "get_rpc_timeout", |
| []() { |
| return RpcAgent::getCurrentRpcAgent()->getRpcTimeout().count() / |
| kSecToMsConversion; |
| }, |
| R"( |
| Retrieve the default timeout for all RPCs that was set during RPC initialization. |
| The returned value will be in seconds. |
| Returns: |
| ``float`` indicating the RPC timeout in seconds. |
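| |
| Example:: |
| A sketch, assuming RPC was initialized with default options: |
| |
| >>> import torch.distributed.rpc as rpc |
| >>> rpc.get_rpc_timeout()  # returns 60.0 under default settings |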
| )"); |
| |
| module.def( |
| "enable_gil_profiling", |
| [](bool flag) { |
| RpcAgent::getCurrentRpcAgent()->enableGILProfiling(flag); |
| }, |
| R"( |
| Set whether GIL wait times should be profiled or not. This incurs a |
| slight overhead cost. Profiling is disabled by default for performance reasons. |
| |
| Args: |
| flag (bool): True to enable GIL profiling, False to disable. |
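| |
| Example:: |
| A sketch, assuming an RPC agent is already set up on this worker: |
| |
| >>> import torch.distributed.rpc as rpc |
| >>> rpc.enable_gil_profiling(True) |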
| )"); |
| |
| module.def( |
| "_set_rpc_timeout", |
| [](const float rpcTimeoutSeconds) { |
| auto rpcTimeout = std::chrono::milliseconds( |
| static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion)); |
| RpcAgent::getCurrentRpcAgent()->setRpcTimeout(rpcTimeout); |
| }, |
| R"( |
| Set the default timeout for all RPCs. The input unit is expected to be |
| in seconds. If an RPC is not completed within this time, an exception |
| indicating it has timed out will be raised. To control timeout for |
| specific RPCs, a timeout parameter can be passed into |
| :meth:`~torch.distributed.rpc.rpc_sync` and |
| :meth:`~torch.distributed.rpc.rpc_async`. |
| |
| Args: |
| rpcTimeoutSeconds (float): Timeout value in seconds. |
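| |
| Example:: |
| A sketch, assuming an RPC agent is already set up on this worker: |
| |
| >>> import torch.distributed.rpc as rpc |
| >>> rpc._set_rpc_timeout(20)  # seconds |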
| )"); |
| |
| module.def( |
| "_enable_server_process_global_profiler", |
| &profiler::processglobal::enableServer); |
| module.def( |
| "_disable_server_process_global_profiler", |
| &profiler::processglobal::disableServer); |
| |
| module.def("_set_profiler_node_id", &at::RecordFunction::setDefaultNodeId); |
| |
| py::class_< |
| RemoteProfilerManager, |
| std::unique_ptr<RemoteProfilerManager, py::nodelete>>( |
| module, "RemoteProfilerManager") |
| .def("set_current_profiling_key", [](const std::string& key) { |
| auto& inst = RemoteProfilerManager::getInstance(); |
| inst.setCurrentKey(key); |
| }); |
| |
| module.def( |
| "_enable_jit_rref_pickle", |
| &enableJitRRefPickle, |
| R"( |
| Allows ``torch.jit.save`` to save a ``torch.jit.ScriptModule`` with |
| pickled RRefs out of RPC contexts. |
| |
| .. warning:: |
| This is dangerous. If the module contains RRefs, the pickled |
| result must be sent over RPC and unpickled on the receiving side |
| to restore the module. Otherwise, there will be RRef leaks, which |
| can potentially cause the program to hang. When using this API, it |
| is the application's responsibility to make sure that the above |
| assumption always holds. |
| )"); |
| module.def("_disable_jit_rref_pickle", &disableJitRRefPickle); |
| |
| Py_RETURN_TRUE; |
| } |
| |
| } // namespace |
| |
| static PyMethodDef methods[] = { // NOLINT |
| {"_rpc_init", rpc_init, METH_NOARGS, nullptr}, |
| {nullptr, nullptr, 0, nullptr}}; |
| |
| PyMethodDef* python_functions() { |
| return methods; |
| } |
| |
| } // namespace rpc |
| } // namespace distributed |
| } // namespace torch |