| #include <pybind11/pybind11.h> |
| #include <torch/csrc/Device.h> |
| #include <torch/csrc/Event.h> |
| #include <torch/csrc/Stream.h> |
| #include <torch/csrc/THP.h> |
| #include <torch/csrc/utils/pybind.h> |
| #include <torch/csrc/utils/pycfunction_helpers.h> |
| #include <torch/csrc/utils/python_arg_parser.h> |
| |
| #include <c10/core/DeviceGuard.h> |
| #include <c10/core/Stream.h> |
| #include <c10/core/impl/DeviceGuardImplInterface.h> |
| #include <c10/util/Exception.h> |
| #include <c10/util/hash.h> |
| #include <structmember.h> |
| #include <cstdint> |
| |
| PyTypeObject* THPStreamClass = nullptr; |
| |
| static PyObject* THPStream_pynew( |
| PyTypeObject* type, |
| PyObject* args, |
| PyObject* kwargs) { |
| HANDLE_TH_ERRORS |
| |
| int64_t stream_id = -1; |
| int64_t device_type = 0; |
| int64_t device_index = 0; |
| int64_t priority = 0; |
| |
| static torch::PythonArgParser parser({ |
| "Stream(Device device=None, *, int64_t priority=0)", |
| "Stream(int64_t stream_id, int64_t device_index, int64_t device_type, *, int64_t priority=0)", |
| }); |
| |
| torch::ParsedArgs<4> parsed_args; |
| auto r = parser.parse(args, kwargs, parsed_args); |
| |
| std::unique_ptr<c10::DeviceGuard> device_guard_ptr; |
| |
| if (r.idx == 0) { |
| auto default_accelerator = at::getAccelerator(false); |
| auto device = r.deviceOptional(0); |
| if (device.has_value()) { |
| device_type = static_cast<int64_t>(device->type()); |
| device_index = static_cast<int64_t>(device->index()); |
| // Initialize device guard if device is not None. |
| device_guard_ptr = std::make_unique<c10::DeviceGuard>(device.value()); |
| } else { |
| // If device is None, we will use the current accelerator and index. |
| // If the current accelerator is not set, we will use the CPU as device |
| // type. |
| device_type = static_cast<int64_t>( |
| default_accelerator.value_or(c10::DeviceType::CPU)); |
| c10::impl::VirtualGuardImpl impl{ |
| static_cast<c10::DeviceType>(device_type)}; |
| const auto current_device = impl.getDevice(); |
| device_index = current_device.index(); |
| } |
| priority = r.toInt64WithDefault(1, 0); |
| } else if (r.idx == 1) { |
| stream_id = r.toInt64WithDefault(0, -1); |
| device_index = r.toInt64WithDefault(1, 0); |
| device_type = |
| r.toInt64WithDefault(2, static_cast<int64_t>(c10::DeviceType::CPU)); |
| priority = r.toInt64WithDefault(3, 0); |
| } else { |
| TORCH_CHECK( |
| false, |
| "parse stream arg fails please check the usage: ", |
| parser.get_signatures()); |
| } |
| |
| THPObjectPtr ptr(type->tp_alloc(type, 0)); |
| if (!ptr) { |
| return nullptr; |
| } |
| |
| THPStream* self = (THPStream*)ptr.get(); |
| |
| // If torch.Stream is not created from existing Stream, then create a new one. |
| // It requires other device backends override getNewStream method. How the new |
| // stream is created is backend specific. Backend should be able to correctly |
| // manage the lifetime of streams. |
| std::optional<c10::Stream> stream_opt; |
| if (r.idx == 0) { |
| c10::impl::VirtualGuardImpl impl{static_cast<c10::DeviceType>(device_type)}; |
| stream_opt = impl.getNewStream( |
| c10::Device(static_cast<c10::DeviceType>(device_type), device_index), |
| static_cast<int>(priority)); |
| } else { |
| stream_opt = c10::Stream::unpack3( |
| stream_id, |
| static_cast<c10::DeviceIndex>(device_index), |
| static_cast<c10::DeviceType>(device_type)); |
| } |
| |
| TORCH_CHECK(stream_opt.has_value(), "Failed to create stream"); |
| self->stream_id = static_cast<int64_t>(stream_opt->id()); |
| self->device_index = static_cast<int64_t>(stream_opt->device_index()); |
| self->device_type = static_cast<int64_t>(stream_opt->device_type()); |
| |
| return (PyObject*)ptr.release(); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| PyObject* THPStream_Wrap(const c10::Stream& stream) { |
| HANDLE_TH_ERRORS |
| auto type = (PyTypeObject*)THPStreamClass; |
| THPObjectPtr ptr(type->tp_alloc(type, 0)); |
| if (!ptr) { |
| throw python_error(); |
| } |
| |
| THPStream* self = (THPStream*)ptr.get(); |
| self->stream_id = stream.id(); |
| // NOLINTNEXTLINE(bugprone-signed-char-misuse) |
| self->device_index = static_cast<int64_t>(stream.device_index()); |
| self->device_type = static_cast<int64_t>(stream.device_type()); |
| return ptr.release(); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static void THPStream_dealloc(THPStream* self) { |
| Py_TYPE(self)->tp_free((PyObject*)self); |
| } |
| |
| static PyObject* THPStream_get_device(THPStream* self, void* unused) { |
| HANDLE_TH_ERRORS |
| return THPDevice_New(c10::Device( |
| static_cast<c10::DeviceType>(self->device_type), |
| static_cast<c10::DeviceIndex>(self->device_index))); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* THPStream_query(PyObject* _self, PyObject* noargs) { |
| HANDLE_TH_ERRORS |
| auto self = (THPStream*)_self; |
| |
| return PyBool_FromLong(c10::Stream::unpack3( |
| self->stream_id, |
| self->device_index, |
| static_cast<c10::DeviceType>(self->device_type)) |
| .query()); |
| |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* THPStream_synchronize(PyObject* _self, PyObject* noargs) { |
| HANDLE_TH_ERRORS { |
| pybind11::gil_scoped_release no_gil; |
| auto self = (THPStream*)_self; |
| |
| c10::Stream::unpack3( |
| self->stream_id, |
| self->device_index, |
| static_cast<c10::DeviceType>(self->device_type)) |
| .synchronize(); |
| } |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* THPStream_wait_event(PyObject* _self, PyObject* _event) { |
| HANDLE_TH_ERRORS { |
| auto self = (THPStream*)_self; |
| auto event = (THPEvent*)_event; |
| c10::Stream::unpack3( |
| self->stream_id, |
| self->device_index, |
| static_cast<c10::DeviceType>(self->device_type)) |
| .wait(event->event); |
| } |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* THPStream_wait_stream(PyObject* _self, PyObject* _other) { |
| HANDLE_TH_ERRORS { |
| auto self = (THPStream*)_self; |
| auto other_stream = (THPStream*)_other; |
| c10::Event new_event( |
| static_cast<c10::DeviceType>(other_stream->device_type), |
| c10::EventFlag::PYTORCH_DEFAULT); |
| new_event.record(c10::Stream::unpack3( |
| other_stream->stream_id, |
| other_stream->device_index, |
| static_cast<c10::DeviceType>(other_stream->device_type))); |
| c10::Stream::unpack3( |
| self->stream_id, |
| self->device_index, |
| static_cast<c10::DeviceType>(self->device_type)) |
| .wait(new_event); |
| } |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* THPStream_record_event( |
| PyObject* _self, |
| PyObject* args, |
| PyObject* kwargs) { |
| HANDLE_TH_ERRORS |
| auto self = (THPStream*)_self; |
| PyObject* _new_event; |
| PyObject* _event = Py_None; |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) |
| constexpr const char* accepted_args[] = {"event", nullptr}; |
| if (!PyArg_ParseTupleAndKeywords( |
| args, |
| kwargs, |
| "|O", |
| // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
| const_cast<char**>(accepted_args), |
| &_event)) { |
| TORCH_CHECK(false, "parse record_event arg fails"); |
| } |
| if (_event != Py_None) { |
| // Increase the refcount of the event to avoid it being destroyed. |
| Py_INCREF(_event); |
| _new_event = _event; |
| } else { |
| _new_event = THPEvent_new( |
| static_cast<c10::DeviceType>(self->device_type), |
| c10::EventFlag::PYTORCH_DEFAULT); |
| } |
| auto new_event = (THPEvent*)_new_event; |
| TORCH_CHECK(new_event, "event must not be null"); |
| new_event->event.record(c10::Stream::unpack3( |
| self->stream_id, |
| self->device_index, |
| static_cast<c10::DeviceType>(self->device_type))); |
| return (PyObject*)new_event; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* THPStream_repr(THPStream* self) { |
| HANDLE_TH_ERRORS |
| return THPUtils_packString( |
| "torch.Stream device_type=" + |
| c10::DeviceTypeName( |
| static_cast<c10::DeviceType>(self->device_type), true) + |
| ", device_index=" + std::to_string(self->device_index) + |
| ", stream_id=" + std::to_string(self->stream_id)); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static Py_hash_t THPStream_hash(THPStream* self) { |
| return static_cast<long>(at::hash_combine( |
| self->device_type, |
| (at::hash_combine(self->stream_id, self->device_index)))); |
| } |
| |
| static PyObject* THPStream_eq(THPStream* self, THPStream* other) { |
| HANDLE_TH_ERRORS |
| return PyBool_FromLong( |
| (self->stream_id == other->stream_id) && |
| (self->device_index == other->device_index) && |
| (self->device_type == other->device_type)); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* THPStream_ne(THPStream* self, THPStream* other) { |
| HANDLE_TH_ERRORS |
| return PyBool_FromLong( |
| (self->stream_id != other->stream_id) || |
| (self->device_index != other->device_index) || |
| (self->device_type != other->device_type)); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* THPStream_richcompare( |
| PyObject* self, |
| PyObject* other, |
| int op) { |
| PyObject* result = NULL; |
| if (other == Py_None) { |
| result = Py_False; |
| } else { |
| switch (op) { |
| case Py_EQ: |
| result = THPStream_eq((THPStream*)self, (THPStream*)other); |
| break; |
| case Py_NE: |
| result = THPStream_ne((THPStream*)self, (THPStream*)other); |
| break; |
| default: |
| result = Py_False; |
| break; |
| } |
| } |
| Py_XINCREF(result); |
| return result; |
| } |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
| static struct PyMemberDef THPStream_members[] = { |
| {"stream_id", |
| T_LONGLONG, |
| offsetof(THPStream, stream_id), |
| READONLY, |
| nullptr}, |
| {"device_index", |
| T_LONGLONG, |
| offsetof(THPStream, device_index), |
| READONLY, |
| nullptr}, |
| {"device_type", |
| T_LONGLONG, |
| offsetof(THPStream, device_type), |
| READONLY, |
| nullptr}, |
| {nullptr}}; |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
| static struct PyGetSetDef THPStream_properties[] = { |
| {"device", (getter)THPStream_get_device, nullptr, nullptr, nullptr}, |
| {nullptr}}; |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
| static PyMethodDef THPStream_methods[] = { |
| {"query", THPStream_query, METH_NOARGS, nullptr}, |
| {"synchronize", THPStream_synchronize, METH_NOARGS, nullptr}, |
| {"wait_event", THPStream_wait_event, METH_O, nullptr}, |
| {"wait_stream", THPStream_wait_stream, METH_O, nullptr}, |
| {"record_event", |
| castPyCFunctionWithKeywords(THPStream_record_event), |
| METH_VARARGS | METH_KEYWORDS, |
| nullptr}, |
| {"__eq__", (PyCFunction)THPStream_eq, METH_O, nullptr}, |
| {nullptr}}; |
| |
| PyTypeObject THPStreamType = { |
| PyVarObject_HEAD_INIT(nullptr, 0) "torch.Stream", /* tp_name */ |
| sizeof(THPStream), /* tp_basicsize */ |
| 0, /* tp_itemsize */ |
| (destructor)THPStream_dealloc, /* tp_dealloc */ |
| 0, /* tp_vectorcall_offset */ |
| nullptr, /* tp_getattr */ |
| nullptr, /* tp_setattr */ |
| nullptr, /* tp_reserved */ |
| (reprfunc)THPStream_repr, /* tp_repr */ |
| nullptr, /* tp_as_number */ |
| nullptr, /* tp_as_sequence */ |
| nullptr, /* tp_as_mapping */ |
| (hashfunc)THPStream_hash, /* tp_hash */ |
| nullptr, /* tp_call */ |
| nullptr, /* tp_str */ |
| nullptr, /* tp_getattro */ |
| nullptr, /* tp_setattro */ |
| nullptr, /* tp_as_buffer */ |
| // NOLINTNEXTLINE(misc-redundant-expression) |
| Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ |
| nullptr, /* tp_doc */ |
| nullptr, /* tp_traverse */ |
| nullptr, /* tp_clear */ |
| THPStream_richcompare, /* tp_richcompare */ |
| 0, /* tp_weaklistoffset */ |
| nullptr, /* tp_iter */ |
| nullptr, /* tp_iternext */ |
| THPStream_methods, /* tp_methods */ |
| THPStream_members, /* tp_members */ |
| THPStream_properties, /* tp_getset */ |
| nullptr, /* tp_base */ |
| nullptr, /* tp_dict */ |
| nullptr, /* tp_descr_get */ |
| nullptr, /* tp_descr_set */ |
| 0, /* tp_dictoffset */ |
| nullptr, /* tp_init */ |
| nullptr, /* tp_alloc */ |
| THPStream_pynew, /* tp_new */ |
| }; |
| |
| void THPStream_init(PyObject* module) { |
| THPStreamClass = &THPStreamType; |
| Py_SET_TYPE(&THPStreamType, &PyType_Type); |
| if (PyType_Ready(&THPStreamType) < 0) { |
| throw python_error(); |
| } |
| Py_INCREF(&THPStreamType); |
| if (PyModule_AddObject(module, "Stream", (PyObject*)&THPStreamType) < 0) { |
| throw python_error(); |
| } |
| } |