Add torch.cuda.get_device_name function (#2540)
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 96c37de..8ab4484 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -534,6 +534,7 @@
extern PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg);
extern PyObject * THCPModule_getDevice_wrap(PyObject *self);
extern PyObject * THCPModule_getDeviceCount_wrap(PyObject *self);
+extern PyObject * THCPModule_getDeviceName_wrap(PyObject *self, PyObject *arg);
extern PyObject * THCPModule_getCurrentStream_wrap(PyObject *self);
extern PyObject * THCPModule_getCurrentBlasHandle_wrap(PyObject *self);
extern PyObject * THCPModule_setStream_wrap(PyObject *self, PyObject *stream);
@@ -567,6 +568,7 @@
{"_cuda_setDevice", (PyCFunction)THCPModule_setDevice_wrap, METH_O, NULL},
{"_cuda_getDevice", (PyCFunction)THCPModule_getDevice_wrap, METH_NOARGS, NULL},
{"_cuda_getDeviceCount", (PyCFunction)THCPModule_getDeviceCount_wrap, METH_NOARGS, NULL},
+ {"_cuda_getDeviceName", (PyCFunction)THCPModule_getDeviceName_wrap, METH_O, NULL},
{"_cuda_getCurrentStream", (PyCFunction)THCPModule_getCurrentStream_wrap, METH_NOARGS, NULL},
{"_cuda_getCurrentBlasHandle", (PyCFunction)THCPModule_getCurrentBlasHandle_wrap, METH_NOARGS, NULL},
{"_cuda_setStream", (PyCFunction)THCPModule_setStream_wrap, METH_O, NULL},
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 19c5174..367a7a0 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -13,6 +13,7 @@
#include "THCP.h"
+#include "torch/csrc/utils/python_strings.h"
#include "ModuleSparse.cpp"
THCState *state;
@@ -120,6 +121,18 @@
END_HANDLE_TH_ERRORS
}
+PyObject * THCPModule_getDeviceName_wrap(PyObject *self, PyObject *arg)
+{
+ HANDLE_TH_ERRORS
+ THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to getDeviceName");
+ long device = THPUtils_unpackLong(arg);
+
+ cudaDeviceProp prop;
+ THCudaCheck(cudaGetDeviceProperties(&prop, device));
+ return THPUtils_packString(prop.name);
+ END_HANDLE_TH_ERRORS
+}
+
PyObject * THCPModule_getCurrentStream_wrap(PyObject *self)
{
HANDLE_TH_ERRORS
diff --git a/torch/csrc/cuda/Module.h b/torch/csrc/cuda/Module.h
index f5929d3..4f5fe3f 100644
--- a/torch/csrc/cuda/Module.h
+++ b/torch/csrc/cuda/Module.h
@@ -7,6 +7,7 @@
void THCPModule_setDevice(int idx);
PyObject * THCPModule_getDevice_wrap(PyObject *self);
PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg);
+PyObject * THCPModule_getDeviceName_wrap(PyObject *self, PyObject *arg);
PyObject * THCPModule_getDriverVersion(PyObject *self);
PyObject * THCPModule_isDriverSufficient(PyObject *self);
PyObject * THCPModule_getCurrentBlasHandle_wrap(PyObject *self);
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index f3b8622..40b9f8f6 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -162,6 +162,17 @@
torch._C._cuda_setDevice(device)
+def get_device_name(device):
+ """Gets the name of a device.
+
+ Arguments:
+ device (int): device for which to return the name. This function is a
+ no-op if this argument is negative.
+ """
+ if device >= 0:
+ return torch._C._cuda_getDeviceName(device)
+
+
@contextlib.contextmanager
def stream(stream):
"""Context-manager that selects a given stream.