export current_blas_handle
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 124e13feb..e3ab79d 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -454,6 +454,7 @@
extern PyObject * THCPModule_getDevice_wrap(PyObject *self);
extern PyObject * THCPModule_getDeviceCount_wrap(PyObject *self);
extern PyObject * THCPModule_getCurrentStream_wrap(PyObject *self);
+extern PyObject * THCPModule_getCurrentBlasHandle_wrap(PyObject *self);
extern PyObject * THCPModule_setStream_wrap(PyObject *self, PyObject *stream);
extern PyObject * THCPModule_getDriverVersion(PyObject *self);
extern PyObject * THCPModule_isDriverSufficient(PyObject *self);
@@ -485,6 +486,7 @@
{"_cuda_getDevice", (PyCFunction)THCPModule_getDevice_wrap, METH_NOARGS, NULL},
{"_cuda_getDeviceCount", (PyCFunction)THCPModule_getDeviceCount_wrap, METH_NOARGS, NULL},
{"_cuda_getCurrentStream", (PyCFunction)THCPModule_getCurrentStream_wrap, METH_NOARGS, NULL},
+ {"_cuda_getCurrentBlasHandle", (PyCFunction)THCPModule_getCurrentBlasHandle_wrap, METH_NOARGS, NULL},
{"_cuda_setStream", (PyCFunction)THCPModule_setStream_wrap, METH_O, NULL},
{"_cuda_isDriverSufficient", (PyCFunction)THCPModule_isDriverSufficient, METH_NOARGS, NULL},
{"_cuda_getDriverVersion", (PyCFunction)THCPModule_getDriverVersion, METH_NOARGS, NULL},
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 17f715a..5198089 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -324,3 +324,11 @@
ncclGetUniqueId(&uniqueId);
}
#endif
+
+// Python binding: returns the current cuBLAS handle for `state` as a Python int holding the pointer value.
+PyObject * THCPModule_getCurrentBlasHandle_wrap(PyObject *self)
+{
+  HANDLE_TH_ERRORS
+  return PyLong_FromVoidPtr(THCState_getCurrentBlasHandle(state));
+  END_HANDLE_TH_ERRORS
+}
diff --git a/torch/csrc/cuda/Module.h b/torch/csrc/cuda/Module.h
index 5309643..f5929d3 100644
--- a/torch/csrc/cuda/Module.h
+++ b/torch/csrc/cuda/Module.h
@@ -9,6 +9,7 @@
PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg);
PyObject * THCPModule_getDriverVersion(PyObject *self);
PyObject * THCPModule_isDriverSufficient(PyObject *self);
+PyObject * THCPModule_getCurrentBlasHandle_wrap(PyObject *self);
#endif
#endif
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index 401e38c..f65f903 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -215,6 +215,16 @@
return torch.cuda.Stream(_cdata=torch._C._cuda_getCurrentStream())
+def current_blas_handle():
+    """Returns the cublasHandle_t pointer to the current cuBLAS handle.
+
+    The pointer value is returned as a Python integer.
+    """
+    # Initialize CUDA first, as _host_allocator() does before calling into torch._C.
+    _lazy_init()
+    return torch._C._cuda_getCurrentBlasHandle()
+
+
def _host_allocator():
_lazy_init()
return torch._C._cuda_cudaHostAllocator()