Add has_lapack flag (#11024)
Summary:
Currently our `skipIfLapack` uses a try-catch block and string-matches against the error message, which is highly unreliable. This PR adds `hasLAPACK` and `hasMAGMA` to the ATen context and exposes the flags to Python.
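As a rough illustration (not part of this diff), the new context flags can be queried from C++ roughly like this, assuming the `at::hasLAPACK()`/`at::hasMAGMA()` helpers added below:

    // Minimal sketch: querying the build-time capability flags added in this PR.
    #include <ATen/Context.h>
    #include <iostream>

    int main() {
      std::cout << "LAPACK available: " << (at::hasLAPACK() ? "yes" : "no") << "\n";
      std::cout << "MAGMA available:  " << (at::hasMAGMA() ? "yes" : "no") << "\n";
      return 0;
    }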
Also fixes a refcounting bug with `PyModule_AddObject`: the method steals a reference, but in some places we didn't `Py_INCREF` before calling it with `Py_True` or `Py_False`.
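For reference, a minimal sketch of the correct pattern (the helper name `add_bool_flag` is illustrative, not from this diff):

    // PyModule_AddObject steals a reference on success, so borrowed
    // singletons like Py_True/Py_False must be Py_INCREF'd first.
    #include <Python.h>

    static bool add_bool_flag(PyObject* module, const char* name, bool value) {
      PyObject* v = value ? Py_True : Py_False;
      Py_INCREF(v);  // give the module its own reference to steal
      if (PyModule_AddObject(module, name, v) != 0) {
        Py_DECREF(v);  // the reference is not stolen on failure
        return false;
      }
      return true;
    }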
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11024
Differential Revision: D9564898
Pulled By: SsnL
fbshipit-source-id: f46862ec3558d7e0058ef48991cd9c720cb317e2
diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index a2c3fb4..5b420d8 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -11,6 +11,8 @@
#include "ATen/CPUGenerator.h"
#include "ATen/RegisterCPU.h"
+#include "TH/TH.h" // for USE_LAPACK
+
#ifdef USE_SSE3
#include <pmmintrin.h>
#endif
@@ -80,6 +82,14 @@
#endif
}
+bool Context::hasLAPACK() const {
+#ifdef USE_LAPACK
+ return true;
+#else
+ return false;
+#endif
+}
+
bool Context::setFlushDenormal(bool on) {
#ifdef USE_SSE3
// Setting flush-to-zero (FTZ) flag
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index 5584963..bab1fa5 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -50,6 +50,10 @@
return *generator;
}
bool hasMKL() const;
+ bool hasLAPACK() const;
+ bool hasMAGMA() const {
+ return detail::getCUDAHooks().hasMAGMA();
+ }
bool hasCUDA() const {
return detail::getCUDAHooks().hasCUDA();
}
@@ -158,6 +162,14 @@
return globalContext().hasMKL();
}
+static inline bool hasLAPACK() {
+ return globalContext().hasLAPACK();
+}
+
+static inline bool hasMAGMA() {
+ return globalContext().hasMAGMA();
+}
+
static inline int64_t current_device() {
return globalContext().current_device();
}
diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp
index 7d73faf..570a375 100644
--- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp
+++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -69,7 +69,7 @@
// let's not if we don't need to!)
std::unique_ptr<THCState, void (*)(THCState*)> CUDAHooks::initCUDA() const {
THCState* thc_state = THCState_alloc();
-
+
THCudaInit(thc_state);
return std::unique_ptr<THCState, void (*)(THCState*)>(
thc_state, [](THCState* p) {
@@ -92,6 +92,14 @@
return true;
}
+bool CUDAHooks::hasMAGMA() const {
+#ifdef USE_MAGMA
+ return true;
+#else
+ return false;
+#endif
+}
+
bool CUDAHooks::hasCuDNN() const {
return AT_CUDNN_ENABLED();
}
diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h
index 766ab62..491adfc 100644
--- a/aten/src/ATen/cuda/detail/CUDAHooks.h
+++ b/aten/src/ATen/cuda/detail/CUDAHooks.h
@@ -13,6 +13,7 @@
std::unique_ptr<THCState, void(*)(THCState*)> initCUDA() const override;
std::unique_ptr<Generator> initCUDAGenerator(Context*) const override;
bool hasCUDA() const override;
+ bool hasMAGMA() const override;
bool hasCuDNN() const override;
int64_t current_device() const override;
Allocator* getPinnedMemoryAllocator() const override;
diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h
index 6b2e87c..cccf6dc 100644
--- a/aten/src/ATen/detail/CUDAHooksInterface.h
+++ b/aten/src/ATen/detail/CUDAHooksInterface.h
@@ -65,6 +65,10 @@
return false;
}
+ virtual bool hasMAGMA() const {
+ return false;
+ }
+
virtual bool hasCuDNN() const {
return false;
}
diff --git a/test/common.py b/test/common.py
index 545ba4f..e7d6940 100644
--- a/test/common.py
+++ b/test/common.py
@@ -112,12 +112,10 @@
def skipIfNoLapack(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
- try:
+ if not torch._C.has_lapack:
+ raise unittest.SkipTest('PyTorch compiled without Lapack')
+ else:
fn(*args, **kwargs)
- except Exception as e:
- if 'Lapack library not found' in repr(e):
- raise unittest.SkipTest('Compiled without Lapack')
- raise
return wrapper
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index af367c3..e17997e 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -584,13 +584,20 @@
ASSERT_TRUE(THCPStream_init(module));
#endif
+ auto set_module_attr = [&](const char* name, PyObject* v, bool incref = true) {
+ // PyModule_AddObject steals reference
+ if (incref) {
+ Py_INCREF(v);
+ }
+ return PyModule_AddObject(module, name, v) == 0;
+ };
+
#ifdef USE_CUDNN
PyObject *has_cudnn = Py_True;
#else
PyObject *has_cudnn = Py_False;
#endif
- Py_INCREF(has_cudnn);
- ASSERT_TRUE(PyModule_AddObject(module, "has_cudnn", has_cudnn) == 0);
+ ASSERT_TRUE(set_module_attr("has_cudnn", has_cudnn));
#ifdef USE_DISTRIBUTED_MW
// See comment on CUDA objects
@@ -611,19 +618,20 @@
// Set ATen warnings to issue Python warnings
at::Warning::set_warning_handler(&warning_handler);
- ASSERT_TRUE(PyModule_AddObject(module, "has_mkl", at::hasMKL() ? Py_True : Py_False) == 0);
+ ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False));
+ ASSERT_TRUE(set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False));
#ifdef _GLIBCXX_USE_CXX11_ABI
- ASSERT_TRUE(PyModule_AddObject(module, "_GLIBCXX_USE_CXX11_ABI",
- _GLIBCXX_USE_CXX11_ABI ? Py_True : Py_False) == 0);
+ ASSERT_TRUE(set_module_attr("_GLIBCXX_USE_CXX11_ABI", _GLIBCXX_USE_CXX11_ABI ? Py_True : Py_False));
#else
- ASSERT_TRUE(PyModule_AddObject(module, "_GLIBCXX_USE_CXX11_ABI", Py_False) == 0);
+ ASSERT_TRUE(set_module_attr("_GLIBCXX_USE_CXX11_ABI", Py_False));
#endif
auto& defaultGenerator = at::globalContext().defaultGenerator(at::kCPU);
THPDefaultGenerator = (THPGenerator*)THPGenerator_NewWithGenerator(
defaultGenerator);
- ASSERT_TRUE(PyModule_AddObject(module, "default_generator", (PyObject*)THPDefaultGenerator) == 0);
+ // This reference is meant to be given away, so no need to incref here.
+ ASSERT_TRUE(set_module_attr("default_generator", (PyObject*)THPDefaultGenerator, /* incref= */ false));
#ifdef USE_NUMPY
if (_import_array() < 0) return NULL;
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index a4fcc6c..8fd95ed 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -333,16 +333,15 @@
THCPCharStorage_postInit(m);
THCPByteStorage_postInit(m);
-#ifdef USE_MAGMA
- THCMagma_init(state);
- bool has_magma = true;
-#else
- bool has_magma = false;
-#endif
+ bool has_magma = at::hasMAGMA();
+ if (has_magma) {
+ THCMagma_init(state);
+ }
bool has_half = true;
auto set_module_attr = [&](const char* name, PyObject* v) {
+ // PyObject_SetAttrString doesn't steal reference. So no need to incref.
if (PyObject_SetAttrString(m, name, v) < 0) {
throw python_error();
}