| #pragma once |
| |
| #include <torch/csrc/distributed/c10d/ProcessGroup.hpp> |
| #include <torch/csrc/jit/python/pybind_utils.h> |
| #include <torch/csrc/utils/pybind.h> |
| |
| namespace c10d { |
| |
| // PyProcessGroup is a pybind11 trampoline class to allow a Python |
| // class to inherit from torch.distributed.ProcessGroup |
| class PyProcessGroup : public ProcessGroup { |
| public: |
| // PyWork is a pybind11 trampoline class to allow a Python |
| // class to inherit from torch.distributed.Work |
| class PyWork : public Work { |
| public: |
| PyWork() = default; |
| |
| bool wait(std::chrono::milliseconds timeout = kNoTimeout) override { |
| PYBIND11_OVERRIDE( |
| bool, /* Return type */ |
| Work, /* Parent class */ |
| wait, /* Name of function in C++ */ |
| timeout); |
| } |
| |
| c10::intrusive_ptr<c10::ivalue::Future> getFuture() override { |
| // We cannot use PYBIND11_OVERRIDE because: |
| // 1. We have to >MANUALLY< unwrap the PyFutureWrapper and |
| // 2. The python name is get_future |
| pybind11::gil_scoped_acquire gil; |
| auto override = |
| pybind11::get_override(static_cast<const Work*>(this), "get_future"); |
| |
| if (override) { |
| py::object o = override(); |
| auto futWrapper = |
| o.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>(); |
| return futWrapper->fut; |
| } |
| |
| return Work::getFuture(); |
| } |
| }; |
| |
| using ProcessGroup::ProcessGroup; |
| |
| const std::string getBackendName() const override { |
| PYBIND11_OVERRIDE_PURE( |
| std::string, /* Return type */ |
| ProcessGroup, /* Parent class */ |
| getBackendName, /* Name of function in C++ */ |
| ); |
| } |
| |
| c10::intrusive_ptr<Work> allgather( |
| std::vector<std::vector<at::Tensor>>& outputTensors, |
| std::vector<at::Tensor>& inputTensors, |
| const AllgatherOptions& opts = AllgatherOptions()) override { |
| PYBIND11_OVERRIDE( |
| c10::intrusive_ptr<Work>, /* Return type */ |
| ProcessGroup, /* Parent class */ |
| allgather, /* Name of function in C++ */ |
| outputTensors, |
| inputTensors, |
| opts); |
| } |
| |
| c10::intrusive_ptr<Work> allreduce( |
| std::vector<at::Tensor>& tensors, |
| const AllreduceOptions& opts = AllreduceOptions()) override { |
| PYBIND11_OVERRIDE( |
| c10::intrusive_ptr<Work>, /* Return type */ |
| ProcessGroup, /* Parent class */ |
| allreduce, /* Name of function in C++ */ |
| tensors, |
| opts); |
| } |
| |
| c10::intrusive_ptr<Work> barrier( |
| const BarrierOptions& opts = BarrierOptions()) override { |
| PYBIND11_OVERRIDE( |
| c10::intrusive_ptr<Work>, /* Return type */ |
| ProcessGroup, /* Parent class */ |
| barrier, /* Name of function in C++ */ |
| opts); |
| } |
| |
| c10::intrusive_ptr<Work> broadcast( |
| std::vector<at::Tensor>& tensors, |
| const BroadcastOptions& opts = BroadcastOptions()) override { |
| PYBIND11_OVERRIDE( |
| c10::intrusive_ptr<Work>, /* Return type */ |
| ProcessGroup, /* Parent class */ |
| broadcast, /* Name of function in C++ */ |
| tensors, |
| opts); |
| } |
| |
| c10::intrusive_ptr<Work> reduce_scatter( |
| std::vector<at::Tensor>& outputTensors, |
| std::vector<std::vector<at::Tensor>>& inputTensors, |
| const ReduceScatterOptions& opts = ReduceScatterOptions()) override { |
| PYBIND11_OVERRIDE( |
| c10::intrusive_ptr<Work>, /* Return type */ |
| ProcessGroup, /* Parent class */ |
| reduce_scatter, /* Name of function in C++ */ |
| outputTensors, |
| inputTensors, |
| opts); |
| } |
| |
| c10::intrusive_ptr<Work> send( |
| std::vector<at::Tensor>& tensors, |
| int dstRank, |
| int tag) override { |
| PYBIND11_OVERRIDE( |
| c10::intrusive_ptr<Work>, /* Return type */ |
| ProcessGroup, /* Parent class */ |
| send, /* Name of function in C++ */ |
| tensors, |
| dstRank, |
| tag); |
| } |
| |
| c10::intrusive_ptr<Work> recv( |
| std::vector<at::Tensor>& tensors, |
| int srcRank, |
| int tag) override { |
| PYBIND11_OVERRIDE( |
| c10::intrusive_ptr<Work>, /* Return type */ |
| ProcessGroup, /* Parent class */ |
| recv, /* Name of function in C++ */ |
| tensors, |
| srcRank, |
| tag); |
| } |
| }; |
| |
| class TORCH_PYTHON_API PythonOnCompletionHook { |
| public: |
| // Wraps a py::object hook and acquires Python GIL in dtor before |
| // destructing the hook object. |
| PythonOnCompletionHook(py::object hook) : hook_(std::move(hook)) {} |
| |
| ~PythonOnCompletionHook() { |
| py::gil_scoped_acquire ag; |
| hook_.dec_ref(); |
| // Explicitly set hook_ to nullptr to prevent py::object's dtor |
| // to decref on the PyObject again. |
| // See Note [Destructing py::object] in python_ivalue.h |
| hook_.ptr() = nullptr; |
| } |
| |
| void operator()(std::shared_ptr<WorkInfo> workInfo) const { |
| std::exception_ptr eptr; |
| { |
| py::gil_scoped_acquire acquire; |
| try { |
| hook_(workInfo); |
| } catch (py::error_already_set& e) { |
| // py::error_already_set requires GIL to destruct, take |
| // special care. |
| eptr = std::make_exception_ptr(std::runtime_error(e.what())); |
| e.restore(); |
| PyErr_Clear(); |
| } catch (std::exception& e) { |
| eptr = std::current_exception(); |
| } |
| } |
| // No more Python-related stuff at this point, i.e., this |
| // exception can be captured and handled by PG backend. |
| if (eptr) |
| std::rethrow_exception(eptr); |
| } |
| |
| private: |
| py::object hook_; |
| }; |
| |
| } // namespace c10d |