blob: 954681d478026f3e7181175f2fed463598719065 [file] [log] [blame]
#include "Device.h"
#include <cstring>
#include <structmember.h>
#include <sstream>
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/utils/object_ptr.h"
#include "torch/csrc/utils/python_arg_parser.h"
#include "torch/csrc/utils/python_strings.h"
PyObject *THPDevice_New(const torch::Device& device)
{
auto type = (PyTypeObject*)&THPDeviceType;
auto self = THPObjectPtr{type->tp_alloc(type, 0)};
if (!self) throw python_error();
auto self_ = reinterpret_cast<THPDevice*>(self.get());
self_->device = device;
return self.release();
}
static const char* cuda_str = "cuda";
static const char* cpu_str = "cpu";
static inline const char* deviceTypeString(torch::DeviceType device_type) {
switch (device_type) {
case torch::DeviceType::CUDA:
return cuda_str;
case torch::DeviceType::CPU:
return cpu_str;
default:
throw std::runtime_error("unexpected device type");
}
}
PyObject *THPDevice_repr(THPDevice *self)
{
std::ostringstream oss;
oss << "device(type=\'" << deviceTypeString(self->device.type) << "\'";
if (!self->device.is_default) {
oss << ", index=" << self->device.index;
}
oss << ")";
return THPUtils_packString(oss.str().c_str());
}
PyObject *THPDevice_str(THPDevice*self)
{
std::ostringstream oss;
if (!self->device.is_default) {
oss << deviceTypeString(self->device.type) << ":" << self->device.index;
} else {
oss << deviceTypeString(self->device.type);
}
return THPUtils_packString(oss.str().c_str());
}
PyObject *THPDevice_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
HANDLE_TH_ERRORS
static torch::PythonArgParser parser({
"Device(Device device)",
"Device(String type, int64_t? index=-1)"
});
torch::ParsedArgs<2> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
auto device = r.device(0);
return THPDevice_New(device);
} else if (r.idx == 1) {
auto as_device = r.device(0); // this works, because device can take strings
auto device_type = r.string(0);
if (!as_device.is_default) {
throw std::runtime_error("type (string) must not include an index because index "
"was passed explicitly: " + device_type);
}
auto is_default = r.isNone(1);
auto device_index = r.toInt64WithDefault(1, -1);
// make sure this is constructible
auto device = torch::Device(as_device.type, device_index, is_default);
return THPDevice_New(device);
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject *THPDevice_type(THPDevice *self)
{
HANDLE_TH_ERRORS
return THPUtils_packString(deviceTypeString(self->device.type));
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject *THPDevice_index(THPDevice *self)
{
HANDLE_TH_ERRORS
if (self->device.is_default) {
Py_RETURN_NONE;
} else {
return THPUtils_packInt64(self->device.index);
}
END_HANDLE_TH_ERRORS
}
PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) {
HANDLE_TH_ERRORS
if (!THPDevice_Check(a) || !THPDevice_Check(b)) {
// Py_RETURN_NOTIMPLEMENTED not in python 2.
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
THPDevice *da = reinterpret_cast<THPDevice*>(a);
THPDevice *db = reinterpret_cast<THPDevice*>(b);
switch(op) {
case Py_EQ:
if (da->device == db->device) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
case Py_NE:
if (da->device == db->device) {
Py_RETURN_FALSE;
} else {
Py_RETURN_TRUE;
}
case Py_LT:
case Py_LE:
case Py_GT:
case Py_GE:
throw torch::TypeError("comparison not implemented");
default:
throw torch::TypeError("unexpected comparison op");
}
END_HANDLE_TH_ERRORS
}
typedef PyObject *(*getter)(PyObject *, void *);
static struct PyGetSetDef THPDevice_properties[] = {
{"type", (getter)THPDevice_type, nullptr, nullptr, nullptr},
{"index", (getter)THPDevice_index, nullptr, nullptr, nullptr},
{nullptr}
};
PyTypeObject THPDeviceType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch.Device", /* tp_name */
sizeof(THPDevice), /* tp_basicsize */
0, /* tp_itemsize */
0, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
(reprfunc)THPDevice_repr, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
(reprfunc)THPDevice_str, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
nullptr, /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
(richcmpfunc)THPDevice_rc, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* tp_methods */
0, /* tp_members */
THPDevice_properties, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
THPDevice_pynew, /* tp_new */
};
void THPDevice_init(PyObject *module)
{
if (PyType_Ready(&THPDeviceType) < 0) {
throw python_error();
}
Py_INCREF(&THPDeviceType);
if (PyModule_AddObject(module, "device", (PyObject *)&THPDeviceType) != 0) {
throw python_error();
}
}