| #include <pybind11/pybind11.h> |
| #include <torch/csrc/Device.h> |
| #include <torch/csrc/THP.h> |
| #include <torch/csrc/utils/pybind.h> |
| #include <torch/csrc/utils/python_arg_parser.h> |
| |
| #include <structmember.h> |
| |
| PyTypeObject* THPStreamClass = nullptr; |
| |
| static PyObject* THPStream_pynew( |
| PyTypeObject* type, |
| PyObject* args, |
| PyObject* kwargs) { |
| HANDLE_TH_ERRORS |
| uint64_t cdata = 0; |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,clang-diagnostic-writable-strings) |
| static char* kwlist[] = {"_cdata", nullptr}; |
| if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|K", kwlist, &cdata)) { |
| return nullptr; |
| } |
| |
| THPObjectPtr ptr(type->tp_alloc(type, 0)); |
| if (!ptr) { |
| return nullptr; |
| } |
| |
| THPStream* self = (THPStream*)ptr.get(); |
| self->cdata = cdata; |
| return (PyObject*)ptr.release(); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static void THPStream_dealloc(THPStream* self) { |
| Py_TYPE(self)->tp_free((PyObject*)self); |
| } |
| |
| static PyObject* THPStream_get_device(THPStream* self, void* unused) { |
| HANDLE_TH_ERRORS |
| return THPDevice_New(c10::Stream::unpack(self->cdata).device()); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* THPStream_eq(THPStream* self, THPStream* other) { |
| HANDLE_TH_ERRORS |
| return PyBool_FromLong(self->cdata == other->cdata); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
| static struct PyMemberDef THPStream_members[] = { |
| {(char*)"_cdata", |
| T_ULONGLONG, |
| offsetof(THPStream, cdata), |
| READONLY, |
| nullptr}, |
| {nullptr}}; |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
| static struct PyGetSetDef THPStream_properties[] = { |
| {"device", (getter)THPStream_get_device, nullptr, nullptr, nullptr}, |
| {nullptr}}; |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
| static PyMethodDef THPStream_methods[] = { |
| {(char*)"__eq__", (PyCFunction)THPStream_eq, METH_O, nullptr}, |
| {nullptr}}; |
| |
// Static type object backing the Python class `torch.Stream`.
// Slots are given positionally, in PyTypeObject declaration order, so the
// order below must not change. Every slot after tp_new is omitted and is
// therefore zero-initialized. THPStream_init() sets ob_type and runs
// PyType_Ready on this object before exposing it to Python.
PyTypeObject THPStreamType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch.Stream", /* tp_name */
    sizeof(THPStream), /* tp_basicsize */
    0, /* tp_itemsize */
    (destructor)THPStream_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    // BASETYPE allows torch.Stream to be subclassed.
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    // NOTE(review): no rich-compare slot is set. The `__eq__` entry in
    // THPStream_methods is a plain method, so `==` on exact torch.Stream
    // instances may fall back to the base type's comparison — confirm.
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPStream_methods, /* tp_methods */
    THPStream_members, /* tp_members */
    THPStream_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    // Left null; presumably filled in from the base type by PyType_Ready.
    // THPStream_pynew calls it via type->tp_alloc.
    nullptr, /* tp_alloc */
    THPStream_pynew, /* tp_new */
};
| |
| void THPStream_init(PyObject* module) { |
| THPStreamClass = &THPStreamType; |
| Py_SET_TYPE(&THPStreamType, &PyType_Type); |
| if (PyType_Ready(&THPStreamType) < 0) { |
| throw python_error(); |
| } |
| Py_INCREF(&THPStreamType); |
| if (PyModule_AddObject(module, "Stream", (PyObject*)&THPStreamType) < 0) { |
| throw python_error(); |
| } |
| } |