Build device generic torch.Stream and torch.Event based on c10::Stream/Event (#123611)
This diff builds a device-generic torch.Stream and torch.Event for newly added accelerators in PyTorch.
------------
**torch.Stream APIs**
```
# Defined in torch/csrc/Stream.cpp
class Stream(_StreamBase):
    stream_id: _int  # Stream id
    device_index: _int
    device_type: _int
    device: _device  # The device of the stream

    @overload
    def __new__(self, device: Optional[DeviceLikeType] = None, *, priority: _int = 0) -> Stream: ...
    @overload
    def __new__(self, stream_id: _int, device_index: _int, device_type: _int, *, priority: _int = 0) -> Stream: ...
    def query(self) -> _bool: ...
    def synchronize(self) -> None: ...
    def wait_event(self, event: Event) -> None: ...
    def wait_stream(self, other: Stream) -> None: ...
    def record_event(self, event: Optional[Event] = None) -> Event: ...
    def __hash__(self) -> _int: ...
    def __repr__(self) -> str: ...
    def __eq__(self, other: object) -> _bool: ...
```
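For illustration, a minimal usage sketch of the Python API above. It is not part of this diff and assumes a build with an accelerator backend that overrides `getNewStream`; the `"cuda"` device below is just an example, any supported backend works:
```
import torch

# Create a new stream on an accelerator device; with no existing stream to
# wrap, this goes through DeviceGuardImplInterface::getNewStream().
s = torch.Stream("cuda")
print(s.device, s.stream_id)

# Rebuild the same stream from its raw ids (second __new__ overload).
s2 = torch.Stream(s.stream_id, s.device_index, s.device_type)
assert s == s2 and hash(s) == hash(s2)

e = s.record_event()   # records a torch.Event on this stream
s.synchronize()        # block until all work submitted to s has finished
print(s.query())       # True once pending work has completed
```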
------------------
**torch.Event APIs**:
- IPC-related APIs are not implemented, since many device backends don't support them, but we keep the interfaces so that torch.cuda.Event can be adapted to them in the future.
- Currently only `enable_timing` is supported, since it is the most commonly used flag across device backends. Supporting additional flags will require refactoring the event flag system in PyTorch.
- An `elapsedTime` API is added to c10::Event.
```
# Defined in torch/csrc/Event.cpp
class Event(_EventBase):
    device: _device  # The device of the Event
    event_id: _int  # The raw event created by device backend

    def __new__(self,
                device: Optional[DeviceLikeType] = None,
                *,
                enable_timing: _bool = False,
                blocking: _bool = False,
                interprocess: _bool = False) -> Event: ...
    @classmethod
    def from_ipc_handle(self, device: DeviceLikeType, ipc_handle: bytes) -> Event: ...
    def record(self, stream: Optional[Stream] = None) -> None: ...
    def wait(self, stream: Optional[Stream] = None) -> None: ...
    def query(self) -> _bool: ...
    def elapsed_time(self, other: Event) -> _float: ...
    def synchronize(self) -> None: ...
    def ipc_handle(self) -> bytes: ...
    def __repr__(self) -> str: ...
```
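A rough sketch of the intended timing usage of the API above (not part of this diff; it assumes an accelerator backend that supports `elapsedTime`, e.g. CUDA):
```
import torch

# Timing sketch with the device-generic Event; "cuda" is just an example
# of a backend that supports elapsedTime.
start = torch.Event("cuda", enable_timing=True)
end = torch.Event("cuda", enable_timing=True)

s = torch.Stream("cuda")
start.record(s)              # record on an explicit stream
# ... enqueue work on s ...
end.record(s)

end.synchronize()               # wait for the second event to complete
print(start.elapsed_time(end))  # elapsed time (milliseconds on CUDA)
```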
-----------
c10::Event provides new APIs:
- Calculate **elapsedTime** between two recorded events.
- Get the raw event id.
- Synchronize the event.
```
double elapsedTime(const Event& event) const {
return impl_.elapsedTime(event.impl_);
}
void* eventId() const {
return impl_.eventId();
}
void synchronize() const {
return impl_.synchronize();
}
```
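From Python, these surface as `Event.event_id`, `Event.synchronize()`, and `Event.elapsed_time()`. A small illustrative sketch, assuming a CUDA-like backend (exact `event_id` values are backend specific):
```
import torch

ev = torch.Event("cuda", enable_timing=True)  # "cuda" is just an example
print(ev.event_id)   # typically 0 before the first record (lazy creation)

ev.synchronize()     # no-op: the event has not been marked for recording

ev.record()          # records on the current stream of the event's device type
ev.synchronize()     # now blocks until the recorded work has completed
print(hex(ev.event_id))  # raw handle from the device backend (backend specific)
```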
----------
TODO: find a good way to test these APIs in PyTorch with API mocks.
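One possible shape for such a test, as a rough sketch only; it exercises the bindings end-to-end on a real accelerator rather than mocking DeviceGuardImplInterface, which is what this TODO is really asking for:
```
import unittest
import torch

class TestGenericStreamEvent(unittest.TestCase):
    # Rough sketch: runs only when an accelerator (here CUDA) is present.
    @unittest.skipUnless(torch.cuda.is_available(), "needs an accelerator")
    def test_stream_event_roundtrip(self):
        s = torch.Stream("cuda")
        e = s.record_event()
        s.synchronize()
        self.assertTrue(s.query())
        self.assertTrue(e.query())
        # A stream rebuilt from its raw ids compares equal to the original.
        self.assertEqual(s, torch.Stream(s.stream_id, s.device_index, s.device_type))

if __name__ == "__main__":
    unittest.main()
```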
Differential Revision: [D55351839](https://our.internmc.facebook.com/intern/diff/D55351839/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123611
Approved by: https://github.com/albanD
diff --git a/build_variables.bzl b/build_variables.bzl
index 5716230..71b5c5e 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -795,6 +795,7 @@
"torch/csrc/StorageMethods.cpp",
"torch/csrc/StorageSharing.cpp",
"torch/csrc/Stream.cpp",
+ "torch/csrc/Event.cpp",
"torch/csrc/TypeInfo.cpp",
"torch/csrc/api/src/python/init.cpp",
"torch/csrc/autograd/functions/init.cpp",
diff --git a/c10/core/Event.h b/c10/core/Event.h
index 2cbaf18..b94db9f 100644
--- a/c10/core/Event.h
+++ b/c10/core/Event.h
@@ -118,6 +118,18 @@
return impl_.query();
}
+ double elapsedTime(const Event& event) const {
+ return impl_.elapsedTime(event.impl_);
+ }
+
+ void* eventId() const {
+ return impl_.eventId();
+ }
+
+ void synchronize() const {
+ return impl_.synchronize();
+ }
+
private:
impl::InlineEvent<impl::VirtualGuardImpl> impl_;
};
diff --git a/c10/core/impl/DeviceGuardImplInterface.h b/c10/core/impl/DeviceGuardImplInterface.h
index 1b168f7..59210a9 100644
--- a/c10/core/impl/DeviceGuardImplInterface.h
+++ b/c10/core/impl/DeviceGuardImplInterface.h
@@ -123,6 +123,16 @@
}
/**
+ * Return a new stream for a given device and priority. The stream will be
+ * copied and shared around; the device backend should be able to correctly
+ * handle the lifetime of the stream.
+ */
+ virtual Stream getNewStream(Device, int priority = 0) const {
+ (void)priority;
+ TORCH_CHECK(false, "Backend doesn't support create a new Stream.")
+ }
+
+ /**
* Set a stream to be the thread local current stream for its device.
* Return the previous stream for that device. You are NOT required
* to set the current device to match the device of this stream.
@@ -195,6 +205,14 @@
}
/**
+ * Wait (by blocking the calling thread) until all the work previously
+ * recorded on the event has completed running on the device.
+ */
+ virtual void synchronizeEvent(void* /*event*/) const {
+ TORCH_CHECK(false, "Backend doesn't support synchronizing events.");
+ }
+
+ /**
* Ensure the caching allocator (if any) is aware that the given DataPtr is
* being used on the given stream, and that it should thus avoid recycling the
* DataPtr until all work on that stream is done.
@@ -203,6 +221,13 @@
}
/**
+ * Fetch the elapsed time between two recorded events.
+ */
+ virtual double elapsedTime(void* /*event1*/, void* /*event2*/) const {
+ TORCH_CHECK(false, "Backend doesn't support elapsedTime.");
+ }
+
+ /**
* Intended use of this class is to leak the DeviceGuardImpl at program end.
* So you better not call the destructor, buster!
*/
@@ -234,6 +259,13 @@
// no-op
return Stream(Stream::DEFAULT, Device(D, -1));
}
+
+ Stream getNewStream(Device, int priority = 0) const override {
+ // no-op
+ (void)priority;
+ return Stream(Stream::DEFAULT, Device(D, -1));
+ }
+
// NB: These do NOT set the current device
Stream exchangeStream(Stream) const noexcept override {
// no-op
diff --git a/c10/core/impl/InlineEvent.h b/c10/core/impl/InlineEvent.h
index ef1e2c6..3485da3 100644
--- a/c10/core/impl/InlineEvent.h
+++ b/c10/core/impl/InlineEvent.h
@@ -101,6 +101,32 @@
return backend_.queryEvent(event_);
}
+ void* eventId() const {
+ return event_;
+ }
+
+ double elapsedTime(const InlineEvent& other) const {
+ TORCH_CHECK(
+ other.was_marked_for_recording(),
+ "other was not marked for recording.");
+ TORCH_CHECK(
+ was_marked_for_recording(), "self was not marked for recording.");
+ TORCH_CHECK(
+ other.device_type() == device_type_,
+ "Event device type ",
+ DeviceTypeName(device_type_),
+ " does not match other's device type ",
+ DeviceTypeName(other.device_type()),
+ ".");
+ return backend_.elapsedTime(event_, other.event_);
+ }
+
+ void synchronize() const {
+ if (!was_marked_for_recording_)
+ return;
+ backend_.synchronizeEvent(event_);
+ }
+
private:
void* event_ = nullptr;
T backend_;
diff --git a/c10/core/impl/VirtualGuardImpl.h b/c10/core/impl/VirtualGuardImpl.h
index ce32411..2065150 100644
--- a/c10/core/impl/VirtualGuardImpl.h
+++ b/c10/core/impl/VirtualGuardImpl.h
@@ -39,6 +39,9 @@
Stream getStream(Device d) const noexcept override {
return impl_->getStream(d);
}
+ Stream getNewStream(Device d, int priority = 0) const override {
+ return impl_->getNewStream(d, priority);
+ }
Stream getDefaultStream(Device d) const override {
return impl_->getDefaultStream(d);
}
@@ -84,6 +87,14 @@
impl_->recordDataPtrOnStream(data_ptr, stream);
}
+ double elapsedTime(void* event1, void* event2) const override {
+ return impl_->elapsedTime(event1, event2);
+ }
+
+ void synchronizeEvent(void* event) const override {
+ return impl_->synchronizeEvent(event);
+ }
+
private:
const DeviceGuardImplInterface* impl_ = nullptr;
};
diff --git a/c10/cuda/impl/CUDAGuardImpl.h b/c10/cuda/impl/CUDAGuardImpl.h
index 7c0ea21..2d983be 100644
--- a/c10/cuda/impl/CUDAGuardImpl.h
+++ b/c10/cuda/impl/CUDAGuardImpl.h
@@ -62,6 +62,9 @@
Stream getDefaultStream(Device d) const override {
return getDefaultCUDAStream(d.index());
}
+ Stream getNewStream(Device d, int priority = 0) const override {
+ return getStreamFromPool(priority, d.index());
+ }
Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
const override {
return getStreamFromPool(isHighPriority, d.index());
diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py
index 65aa339..18e373e 100644
--- a/test/test_public_bindings.py
+++ b/test/test_public_bindings.py
@@ -228,6 +228,7 @@
"StaticModule",
"Stream",
"StreamObjType",
+ "Event",
"StringType",
"SUM",
"SymFloat",
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 1c6e40b..882055e 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -112,7 +112,44 @@
device_index: _int
device_type: _int
- device: device # The device of the stream
+ device: _device # The device of the stream
+
+ @overload
+ def __new__(self, device: Optional[DeviceLikeType] = None, *, priority: _int = 0) -> Stream: ...
+ @overload
+ def __new__(self, stream_id: _int, device_index: _int, device_type: _int, *, priority: _int = 0) -> Stream: ...
+ def query(self) -> _bool: ...
+ def synchronize(self) -> None: ...
+ def wait_event(self, event: Event) -> None: ...
+ def wait_stream(self, other: Stream) -> None: ...
+ def record_event(self, event: Optional[Event] = None) -> Event: ...
+ def __hash__(self) -> _int: ...
+ def __repr__(self) -> str: ...
+ def __eq__(self, other: object) -> _bool: ...
+
+
+# Defined in torch/csrc/Event.cpp
+class Event:
+
+ device: _device # The device of the Event
+ event_id: _int # The raw event created by device backend
+
+ def __new__(self,
+ device: Optional[DeviceLikeType] = None,
+ *,
+ enable_timing: _bool = False,
+ blocking: _bool = False,
+ interprocess: _bool = False) -> Event: ...
+ @classmethod
+ def from_ipc_handle(self, device: _device, ipc_handle: bytes) -> Event: ...
+ def record(self, stream: Optional[Stream] = None) -> None: ...
+ def wait(self, stream: Optional[Stream] = None) -> None: ...
+ def query(self) -> _bool: ...
+ def elapsed_time(self, other: Event) -> _float: ...
+ def synchronize(self) -> None: ...
+ def ipc_handle(self) -> bytes: ...
+ def __repr__(self) -> str: ...
+
# Defined in torch/csrc/Size.cpp
class Size(Tuple[_int, ...]):
diff --git a/torch/csrc/Event.cpp b/torch/csrc/Event.cpp
new file mode 100644
index 0000000..b8cf8b2
--- /dev/null
+++ b/torch/csrc/Event.cpp
@@ -0,0 +1,328 @@
+#include <pybind11/pybind11.h>
+#include <torch/csrc/Device.h>
+#include <torch/csrc/Event.h>
+#include <torch/csrc/Stream.h>
+#include <torch/csrc/THP.h>
+#include <torch/csrc/utils/pybind.h>
+#include <torch/csrc/utils/pycfunction_helpers.h>
+#include <torch/csrc/utils/python_arg_parser.h>
+
+#include <c10/core/Event.h>
+#include <c10/core/Stream.h>
+
+#include <c10/core/DeviceType.h>
+#include <c10/core/impl/DeviceGuardImplInterface.h>
+#include <structmember.h>
+#include <string>
+
+PyObject* THPEventClass = nullptr;
+
+static PyObject* THPEvent_pynew(
+ PyTypeObject* type,
+ PyObject* args,
+ PyObject* kwargs) {
+ HANDLE_TH_ERRORS
+
+ unsigned char enable_timing = 0;
+ unsigned char blocking = 0;
+ unsigned char interprocess = 0;
+
+ static torch::PythonArgParser parser({
+ "Event(Device device=None, *, bool enable_timing=True, bool blocking=False, bool interprocess=False)",
+ });
+
+ torch::ParsedArgs<4> parsed_args;
+ auto r = parser.parse(args, kwargs, parsed_args);
+
+ auto device = r.deviceOptional(0);
+
+ if (!device.has_value()) {
+ device = at::Device(at::getAccelerator(false).value_or(at::kCPU));
+ }
+ enable_timing = r.toBoolWithDefault(1, true);
+ blocking = r.toBoolWithDefault(2, false);
+ interprocess = r.toBoolWithDefault(3, false);
+
+ THPObjectPtr ptr(type->tp_alloc(type, 0));
+ if (!ptr) {
+ TORCH_CHECK(ptr, "Failed to allocate memory for Event");
+ }
+
+ THPEvent* self = (THPEvent*)ptr.get();
+
+ // TODO: blocking and interprocess are not supported yet. To support them, the
+ // flag system of c10::Event needs to be refactored. C10::Event should also
+ // provide a generic constructor to support blocking and interprocess events.
+ (void)blocking;
+ (void)interprocess;
+
+ new (&self->event) c10::Event(
+ device->type(),
+ (enable_timing ? c10::EventFlag::PYTORCH_DEFAULT
+ : c10::EventFlag::BACKEND_DEFAULT));
+
+ return (PyObject*)ptr.release();
+ END_HANDLE_TH_ERRORS
+}
+
+PyObject* THPEvent_new(c10::DeviceType device_type, c10::EventFlag flag) {
+ auto type = (PyTypeObject*)&THPEventType;
+ auto self = THPObjectPtr{type->tp_alloc(type, 0)};
+ TORCH_CHECK(self, "Failed to allocate memory for Event");
+ auto self_ = reinterpret_cast<THPEvent*>(self.get());
+ new (&self_->event) c10::Event(device_type, flag);
+ return self.release();
+}
+
+static void THPEvent_dealloc(THPEvent* self) {
+ {
+ pybind11::gil_scoped_release no_gil{};
+ self->event.~Event();
+ }
+ Py_TYPE(self)->tp_free((PyObject*)self);
+}
+
+static PyObject* THPEvent_get_device(THPEvent* self, void* unused) {
+ HANDLE_TH_ERRORS
+ at::optional<at::Device> device = self->event.device();
+ if (!device) {
+ Py_RETURN_NONE;
+ }
+ return THPDevice_New(device.value());
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_record(
+ PyObject* _self,
+ PyObject* args,
+ PyObject* kwargs) {
+ HANDLE_TH_ERRORS
+ auto self = (THPEvent*)_self;
+ PyObject* _stream = Py_None;
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
+ constexpr const char* accepted_args[] = {"stream", nullptr};
+ if (!PyArg_ParseTupleAndKeywords(
+ args,
+ kwargs,
+ "|O",
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ const_cast<char**>(accepted_args),
+ &_stream)) {
+ TORCH_WARN("Parsing THPEvent_record arg fails");
+ return nullptr;
+ }
+ if (_stream != Py_None) {
+ auto stream = (THPStream*)_stream;
+ self->event.record(c10::Stream::unpack3(
+ stream->stream_id,
+ stream->device_index,
+ static_cast<c10::DeviceType>(stream->device_type)));
+ } else {
+ c10::impl::VirtualGuardImpl impl{
+ static_cast<c10::DeviceType>(self->event.device_type())};
+ self->event.record(impl.getStream(impl.getDevice()));
+ }
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_from_ipc_handle(
+ PyObject* _type,
+ PyObject* args,
+ PyObject* kwargs) {
+ HANDLE_TH_ERRORS
+ auto type = (PyTypeObject*)_type;
+
+ static torch::PythonArgParser parser({
+ "from_ipc_handle(Device device, std::string ipc_handle)",
+ });
+ torch::ParsedArgs<2> parsed_args;
+ auto r = parser.parse(args, kwargs, parsed_args);
+
+ at::Device device = r.device(0);
+ std::string handle_string = r.string(1);
+ TORCH_CHECK_NOT_IMPLEMENTED(
+ false,
+ "torch.Event ipc is not supported yet, please open an issue if you need this!");
+ THPObjectPtr ptr(type->tp_alloc(type, 0));
+ if (!ptr) {
+ return nullptr;
+ }
+ THPEvent* self = (THPEvent*)ptr.get();
+
+ // TODO: to construct an event from an ipc handle, c10::Event needs a more
+ // general constructor.
+ new (&self->event) c10::Event(device.type(), c10::EventFlag::PYTORCH_DEFAULT);
+
+ return (PyObject*)ptr.release();
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_ipc_handle(PyObject* _self, PyObject* noargs) {
+ HANDLE_TH_ERRORS
+ auto self = (THPEvent*)_self;
+ (void)self;
+ TORCH_CHECK_NOT_IMPLEMENTED(
+ false,
+ "torch.Event ipc is not supported yet, please open an issue if you need this!");
+ std::string handle = "0";
+ return PyBytes_FromStringAndSize((const char*)&handle, sizeof(handle));
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_wait(
+ PyObject* _self,
+ PyObject* args,
+ PyObject* kwargs) {
+ HANDLE_TH_ERRORS {
+ auto self = (THPEvent*)_self;
+ PyObject* _stream = Py_None;
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
+ constexpr const char* accepted_args[] = {"stream", nullptr};
+ if (!PyArg_ParseTupleAndKeywords(
+ args,
+ kwargs,
+ "|O",
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ const_cast<char**>(accepted_args),
+ &_stream)) {
+ TORCH_WARN("Parsing THPEvent_wait arg fails");
+ return nullptr;
+ }
+ if (_stream != Py_None) {
+ auto stream = (THPStream*)_stream;
+ self->event.block(c10::Stream::unpack3(
+ stream->stream_id,
+ stream->device_index,
+ static_cast<c10::DeviceType>(stream->device_type)));
+ } else {
+ c10::impl::VirtualGuardImpl impl{
+ static_cast<c10::DeviceType>(self->event.device_type())};
+ self->event.block(impl.getStream(impl.getDevice()));
+ }
+ }
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_query(PyObject* _self, PyObject* noargs) {
+ HANDLE_TH_ERRORS
+ auto self = (THPEvent*)_self;
+ return PyBool_FromLong(self->event.query());
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_elapsed_time(PyObject* _self, PyObject* _other) {
+ HANDLE_TH_ERRORS
+ auto self = (THPEvent*)_self;
+ auto other = (THPEvent*)_other;
+ return PyFloat_FromDouble(self->event.elapsedTime(other->event));
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_synchronize(PyObject* _self, PyObject* noargs) {
+ HANDLE_TH_ERRORS {
+ pybind11::gil_scoped_release no_gil{};
+ auto self = (THPEvent*)_self;
+ self->event.synchronize();
+ }
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_event_id(PyObject* _self, PyObject* noargs) {
+ HANDLE_TH_ERRORS
+ auto self = (THPEvent*)_self;
+ return PyLong_FromVoidPtr(self->event.eventId());
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_repr(THPEvent* self) {
+ HANDLE_TH_ERRORS
+ return THPUtils_packString(
+ "torch.Event device_type=" +
+ c10::DeviceTypeName(
+ static_cast<c10::DeviceType>(self->event.device_type()), true) +
+ ", device_index=" + std::to_string(self->event.device_index()) +
+ ", event_flag=" +
+ std::to_string(static_cast<int64_t>(self->event.flag())) + ", event_id=" +
+ std::to_string(reinterpret_cast<int64_t>(self->event.eventId())));
+ END_HANDLE_TH_ERRORS
+}
+
+// NOLINTNEXTLINE(*c-arrays*, *global-variables)
+static struct PyGetSetDef THPEvent_properties[] = {
+ {"device", (getter)THPEvent_get_device, nullptr, nullptr, nullptr},
+ {"event_id", (getter)THPEvent_evend_id, nullptr, nullptr, nullptr},
+ {nullptr}};
+
+// NOLINTNEXTLINE(*c-arrays*, *global-variables)
+static PyMethodDef THPEvent_methods[] = {
+ {(char*)"from_ipc_handle",
+ castPyCFunctionWithKeywords(THPEvent_from_ipc_handle),
+ METH_CLASS | METH_VARARGS | METH_KEYWORDS,
+ nullptr},
+ {(char*)"record",
+ castPyCFunctionWithKeywords(THPEvent_record),
+ METH_VARARGS | METH_KEYWORDS,
+ nullptr},
+ {(char*)"wait",
+ castPyCFunctionWithKeywords(THPEvent_wait),
+ METH_VARARGS | METH_KEYWORDS,
+ nullptr},
+ {(char*)"query", THPEvent_query, METH_NOARGS, nullptr},
+ {(char*)"elapsed_time", THPEvent_elapsed_time, METH_O, nullptr},
+ {(char*)"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr},
+ {(char*)"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr},
+ {nullptr}};
+
+PyTypeObject THPEventType = {
+ PyVarObject_HEAD_INIT(nullptr, 0) "torch.Event", /* tp_name */
+ sizeof(THPEvent), /* tp_basicsize */
+ 0, /* tp_itemsize */
+ (destructor)THPEvent_dealloc, /* tp_dealloc */
+ 0, /* tp_vectorcall_offset */
+ nullptr, /* tp_getattr */
+ nullptr, /* tp_setattr */
+ nullptr, /* tp_reserved */
+ (reprfunc)THPEvent_repr, /* tp_repr */
+ nullptr, /* tp_as_number */
+ nullptr, /* tp_as_sequence */
+ nullptr, /* tp_as_mapping */
+ nullptr, /* tp_hash */
+ nullptr, /* tp_call */
+ nullptr, /* tp_str */
+ nullptr, /* tp_getattro */
+ nullptr, /* tp_setattro */
+ nullptr, /* tp_as_buffer */
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
+ nullptr, /* tp_doc */
+ nullptr, /* tp_traverse */
+ nullptr, /* tp_clear */
+ nullptr, /* tp_richcompare */
+ 0, /* tp_weaklistoffset */
+ nullptr, /* tp_iter */
+ nullptr, /* tp_iternext */
+ THPEvent_methods, /* tp_methods */
+ nullptr, /* tp_members */
+ THPEvent_properties, /* tp_getset */
+ nullptr, /* tp_base */
+ nullptr, /* tp_dict */
+ nullptr, /* tp_descr_get */
+ nullptr, /* tp_descr_set */
+ 0, /* tp_dictoffset */
+ nullptr, /* tp_init */
+ nullptr, /* tp_alloc */
+ THPEvent_pynew, /* tp_new */
+};
+
+void THPEvent_init(PyObject* module) {
+ THPEventClass = (PyObject*)&THPEventType;
+ if (PyType_Ready(&THPEventType) < 0) {
+ throw python_error();
+ }
+ Py_INCREF(&THPEventType);
+ if (PyModule_AddObject(module, "Event", (PyObject*)&THPEventType) < 0) {
+ throw python_error();
+ }
+}
diff --git a/torch/csrc/Event.h b/torch/csrc/Event.h
new file mode 100644
index 0000000..745610d
--- /dev/null
+++ b/torch/csrc/Event.h
@@ -0,0 +1,21 @@
+#ifndef THP_EVENT_INC
+#define THP_EVENT_INC
+
+#include <c10/core/Event.h>
+#include <torch/csrc/python_headers.h>
+
+struct TORCH_API THPEvent {
+ PyObject_HEAD c10::Event event;
+};
+extern PyObject* THPEventClass;
+TORCH_API extern PyTypeObject THPEventType;
+
+TORCH_API void THPEvent_init(PyObject* module);
+TORCH_API PyObject* THPEvent_new(
+ c10::DeviceType device_type,
+ c10::EventFlag flag);
+inline bool THPEvent_Check(PyObject* obj) {
+ return THPEventClass && PyObject_IsInstance(obj, THPEventClass);
+}
+
+#endif // THP_EVENT_INC
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 9343a48..8aff730 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -39,6 +39,7 @@
#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
+#include <torch/csrc/Event.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/MemoryFormat.h>
@@ -1603,6 +1604,7 @@
THPQScheme_init(module);
THPDevice_init(module);
THPStream_init(module);
+ THPEvent_init(module);
ASSERT_TRUE(THPVariable_initModule(module));
ASSERT_TRUE(THPFunction_initModule(module));
ASSERT_TRUE(THPEngine_initModule(module));
diff --git a/torch/csrc/Stream.cpp b/torch/csrc/Stream.cpp
index bd8abb0..06dac51 100644
--- a/torch/csrc/Stream.cpp
+++ b/torch/csrc/Stream.cpp
@@ -1,10 +1,19 @@
#include <pybind11/pybind11.h>
#include <torch/csrc/Device.h>
+#include <torch/csrc/Event.h>
+#include <torch/csrc/Stream.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/utils/pybind.h>
+#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_arg_parser.h>
+#include <c10/core/DeviceGuard.h>
+#include <c10/core/Stream.h>
+#include <c10/core/impl/DeviceGuardImplInterface.h>
+#include <c10/util/Exception.h>
+#include <c10/util/hash.h>
#include <structmember.h>
+#include <cstdint>
PyTypeObject* THPStreamClass = nullptr;
@@ -13,22 +22,53 @@
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
- int64_t stream_id = 0;
- int64_t device_index = 0;
+
+ int64_t stream_id = -1;
int64_t device_type = 0;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
- constexpr const char* kwlist[] = {
- "stream_id", "device_index", "device_type", nullptr};
- if (!PyArg_ParseTupleAndKeywords(
- args,
- kwargs,
- "|LLL",
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- const_cast<char**>(kwlist),
- &stream_id,
- &device_index,
- &device_type)) {
- return nullptr;
+ int64_t device_index = 0;
+ int64_t priority = 0;
+
+ static torch::PythonArgParser parser({
+ "Steram(Device device=None, *, int64_t priority=0)",
+ "Stream(int64_t stream_id, int64_t device_index, int64_t device_type, *, int64_t priority=0)",
+ });
+
+ torch::ParsedArgs<4> parsed_args;
+ auto r = parser.parse(args, kwargs, parsed_args);
+
+ std::unique_ptr<c10::DeviceGuard> device_guard_ptr;
+
+ if (r.idx == 0) {
+ auto default_accelerator = at::getAccelerator(false);
+ auto device = r.deviceOptional(0);
+ if (device.has_value()) {
+ device_type = static_cast<int64_t>(device->type());
+ device_index = static_cast<int64_t>(device->index());
+ // Initialize device guard if device is not None.
+ device_guard_ptr = std::make_unique<c10::DeviceGuard>(device.value());
+ } else {
+ // If device is None, we will use the current accelerator and index.
+ // If the current accelerator is not set, we will use the CPU as device
+ // type.
+ device_type = static_cast<int64_t>(
+ default_accelerator.value_or(c10::DeviceType::CPU));
+ c10::impl::VirtualGuardImpl impl{
+ static_cast<c10::DeviceType>(device_type)};
+ const auto current_device = impl.getDevice();
+ device_index = current_device.index();
+ }
+ priority = r.toInt64WithDefault(1, 0);
+ } else if (r.idx == 1) {
+ stream_id = r.toInt64WithDefault(0, -1);
+ device_index = r.toInt64WithDefault(1, 0);
+ device_type =
+ r.toInt64WithDefault(2, static_cast<int64_t>(c10::DeviceType::CPU));
+ priority = r.toInt64WithDefault(3, 0);
+ } else {
+ TORCH_CHECK(
+ false,
+ "parse stream arg fails please check the usage: ",
+ parser.get_signatures());
}
THPObjectPtr ptr(type->tp_alloc(type, 0));
@@ -37,9 +77,29 @@
}
THPStream* self = (THPStream*)ptr.get();
- self->stream_id = stream_id;
- self->device_index = device_index;
- self->device_type = device_type;
+
+ // If torch.Stream is not created from an existing Stream, create a new one.
+ // This requires the device backend to override the getNewStream method. How
+ // the new stream is created is backend specific; the backend should be able
+ // to correctly manage the lifetime of streams.
+ c10::optional<c10::Stream> stream_opt;
+ if (r.idx == 0) {
+ c10::impl::VirtualGuardImpl impl{static_cast<c10::DeviceType>(device_type)};
+ stream_opt = impl.getNewStream(
+ c10::Device(static_cast<c10::DeviceType>(device_type), device_index),
+ static_cast<int>(priority));
+ } else {
+ stream_opt = c10::Stream::unpack3(
+ stream_id,
+ static_cast<c10::DeviceIndex>(device_index),
+ static_cast<c10::DeviceType>(device_type));
+ }
+
+ TORCH_CHECK(stream_opt.has_value(), "Failed to create stream");
+ self->stream_id = static_cast<int64_t>(stream_opt->id());
+ self->device_index = static_cast<int64_t>(stream_opt->device_index());
+ self->device_type = static_cast<int64_t>(stream_opt->device_type());
+
return (PyObject*)ptr.release();
END_HANDLE_TH_ERRORS
}
@@ -73,15 +133,167 @@
END_HANDLE_TH_ERRORS
}
+static PyObject* THPStream_query(PyObject* _self, PyObject* noargs) {
+ HANDLE_TH_ERRORS
+ auto self = (THPStream*)_self;
+
+ return PyBool_FromLong(c10::Stream::unpack3(
+ self->stream_id,
+ self->device_index,
+ static_cast<c10::DeviceType>(self->device_type))
+ .query());
+
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPStream_synchronize(PyObject* _self, PyObject* noargs) {
+ HANDLE_TH_ERRORS {
+ pybind11::gil_scoped_release no_gil;
+ auto self = (THPStream*)_self;
+
+ c10::Stream::unpack3(
+ self->stream_id,
+ self->device_index,
+ static_cast<c10::DeviceType>(self->device_type))
+ .synchronize();
+ }
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPStream_wait_event(PyObject* _self, PyObject* _event) {
+ HANDLE_TH_ERRORS {
+ auto self = (THPStream*)_self;
+ auto event = (THPEvent*)_event;
+ c10::Stream::unpack3(
+ self->stream_id,
+ self->device_index,
+ static_cast<c10::DeviceType>(self->device_type))
+ .wait(event->event);
+ }
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPStream_wait_stream(PyObject* _self, PyObject* _other) {
+ HANDLE_TH_ERRORS {
+ auto self = (THPStream*)_self;
+ auto other_stream = (THPStream*)_other;
+ c10::Event new_event(
+ static_cast<c10::DeviceType>(other_stream->device_type),
+ c10::EventFlag::PYTORCH_DEFAULT);
+ new_event.record(c10::Stream::unpack3(
+ other_stream->stream_id,
+ other_stream->device_index,
+ static_cast<c10::DeviceType>(other_stream->device_type)));
+ c10::Stream::unpack3(
+ self->stream_id,
+ self->device_index,
+ static_cast<c10::DeviceType>(self->device_type))
+ .wait(new_event);
+ }
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPStream_record_event(
+ PyObject* _self,
+ PyObject* args,
+ PyObject* kwargs) {
+ HANDLE_TH_ERRORS
+ auto self = (THPStream*)_self;
+ PyObject* _new_event;
+ PyObject* _event = Py_None;
+
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
+ constexpr const char* accepted_args[] = {"event", nullptr};
+ if (!PyArg_ParseTupleAndKeywords(
+ args,
+ kwargs,
+ "|O",
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ const_cast<char**>(accepted_args),
+ &_event)) {
+ TORCH_CHECK(false, "parse record_event arg fails");
+ }
+ if (_event != Py_None) {
+ // Increase the refcount of the event to avoid it being destroyed.
+ Py_INCREF(_event);
+ _new_event = _event;
+ } else {
+ _new_event = THPEvent_new(
+ static_cast<c10::DeviceType>(self->device_type),
+ c10::EventFlag::PYTORCH_DEFAULT);
+ }
+ auto new_event = (THPEvent*)_new_event;
+ TORCH_CHECK(new_event, "event must not be null");
+ new_event->event.record(c10::Stream::unpack3(
+ self->stream_id,
+ self->device_index,
+ static_cast<c10::DeviceType>(self->device_type)));
+ return (PyObject*)new_event;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPStream_repr(THPStream* self) {
+ HANDLE_TH_ERRORS
+ return THPUtils_packString(
+ "torch.Stream device_type=" +
+ c10::DeviceTypeName(
+ static_cast<c10::DeviceType>(self->device_type), true) +
+ ", device_index=" + std::to_string(self->device_index) +
+ ", stream_id=" + std::to_string(self->stream_id));
+ END_HANDLE_TH_ERRORS
+}
+
+static Py_hash_t THPStream_hash(THPStream* self) {
+ return static_cast<long>(at::hash_combine(
+ self->device_type,
+ (at::hash_combine(self->stream_id, self->device_index))));
+}
+
static PyObject* THPStream_eq(THPStream* self, THPStream* other) {
HANDLE_TH_ERRORS
return PyBool_FromLong(
- self->stream_id == other->stream_id &&
- self->device_index == other->device_index &&
- self->device_type == other->device_type);
+ (self->stream_id == other->stream_id) &&
+ (self->device_index == other->device_index) &&
+ (self->device_type == other->device_type));
END_HANDLE_TH_ERRORS
}
+static PyObject* THPStream_ne(THPStream* self, THPStream* other) {
+ HANDLE_TH_ERRORS
+ return PyBool_FromLong(
+ (self->stream_id != other->stream_id) ||
+ (self->device_index != other->device_index) ||
+ (self->device_type != other->device_type));
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPStream_richcompare(
+ PyObject* self,
+ PyObject* other,
+ int op) {
+ PyObject* result = NULL;
+ if (other == Py_None) {
+ result = Py_False;
+ } else {
+ switch (op) {
+ case Py_EQ:
+ result = THPStream_eq((THPStream*)self, (THPStream*)other);
+ break;
+ case Py_NE:
+ result = THPStream_ne((THPStream*)self, (THPStream*)other);
+ break;
+ default:
+ result = Py_False;
+ break;
+ }
+ }
+ Py_XINCREF(result);
+ return result;
+}
+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyMemberDef THPStream_members[] = {
{"stream_id",
@@ -108,6 +320,14 @@
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static PyMethodDef THPStream_methods[] = {
+ {"query", THPStream_query, METH_NOARGS, nullptr},
+ {"synchronize", THPStream_synchronize, METH_NOARGS, nullptr},
+ {"wait_event", THPStream_wait_event, METH_O, nullptr},
+ {"wait_stream", THPStream_wait_stream, METH_O, nullptr},
+ {"record_event",
+ castPyCFunctionWithKeywords(THPStream_record_event),
+ METH_VARARGS | METH_KEYWORDS,
+ nullptr},
{"__eq__", (PyCFunction)THPStream_eq, METH_O, nullptr},
{nullptr}};
@@ -120,11 +340,11 @@
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
- nullptr, /* tp_repr */
+ (reprfunc)THPStream_repr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
- nullptr, /* tp_hash */
+ (hashfunc)THPStream_hash, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
@@ -135,7 +355,7 @@
nullptr, /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
- nullptr, /* tp_richcompare */
+ THPStream_richcompare, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */